-
Notifications
You must be signed in to change notification settings - Fork 34
MAINT: simplify torch
dtype promotion
#303
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
array_api_compat/torch/_aliases.py:140
- The try/except block in _result_type silently swallows a KeyError when a dtype pair is not found in the promotion table. Consider handling the error explicitly or adding a clarifying comment to document the intended fallback behavior.
return _promotion_table[xdt, ydt]
array_api_compat/torch/_aliases.py:296
- Consider adding tests for _sum_prod_no_axis to verify that the dtype conversion and deep copy behavior work as expected, especially for edge cases involving small integer types like uint8.
return x.clone() if dtype == x.dtype else x.to(dtype)
(torch.complex64, torch.complex128): torch.complex128, | ||
(torch.complex128, torch.complex64): torch.complex128, | ||
(torch.complex128, torch.complex128): torch.complex128, | ||
# Mixed float and complex | ||
(torch.float32, torch.complex64): torch.complex64, | ||
(torch.float32, torch.complex128): torch.complex128, | ||
(torch.float64, torch.complex64): torch.complex128, | ||
(torch.float64, torch.complex128): torch.complex128, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(complex, float)
use cases were missing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @crusaderky
if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): | ||
return x.to(torch.int64) | ||
|
||
return x.clone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: this looks scary, but is in fact just a refactoring. Previously this stanza was duplicated in sum
and prod
.
Returning a copy looks reasonable, too.
Salvaged from #298
sum(x, axis=(), dtype=x.dtype)
andprod
with the same parameters, which were previously returningx
itself, to return a deep copy instead. While this is not part of the Array API, it's what both numpy an base torch do, and probably the most healthy thing to do xref https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html.Discussion here: Spell out where views are allowed array-api#921