MAINT: simplify torch dtype promotion #303


Merged: 3 commits into data-apis:main, Apr 15, 2025

Conversation

crusaderky (Contributor)

Salvaged from #298

Copilot AI review requested due to automatic review settings, April 10, 2025 10:01

@Copilot Copilot AI (Contributor) left a comment


Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (2)

array_api_compat/torch/_aliases.py:140

  • The try/except block in _result_type silently swallows a KeyError when a dtype pair is not found in the promotion table. Consider handling the error explicitly or adding a clarifying comment to document the intended fallback behavior.
return _promotion_table[xdt, ydt]

array_api_compat/torch/_aliases.py:296

  • Consider adding tests for _sum_prod_no_axis to verify that the dtype conversion and deep copy behavior work as expected, especially for edge cases involving small integer types like uint8.
return x.clone() if dtype == x.dtype else x.to(dtype)

(torch.complex64, torch.complex128): torch.complex128,
(torch.complex128, torch.complex64): torch.complex128,
(torch.complex128, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
(torch.float64, torch.complex64): torch.complex128,
(torch.float64, torch.complex128): torch.complex128,
@crusaderky (Contributor, Author) commented:


The (complex, float) promotion pairs were missing.
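A torch-free sketch of that fix: start from the (float, complex) entries shown above (strings stand in for the dtypes) and generate the mirrored pairs, since dtype promotion is symmetric. The `mixed` dict is illustrative, not the PR's actual table:

```python
# Strings stand in for torch dtypes; entries mirror the snippet above.
mixed = {
    ("float32", "complex64"): "complex64",
    ("float32", "complex128"): "complex128",
    ("float64", "complex64"): "complex128",
    ("float64", "complex128"): "complex128",
}
# Add the previously missing (complex, float) direction by mirroring.
mixed.update({(b, a): res for (a, b), res in list(mixed.items())})
```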

@ev-br (Member) left a comment


LGTM, thanks @crusaderky

if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
return x.to(torch.int64)

return x.clone()
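The rule above — small integer dtypes accumulate in int64, everything else keeps its dtype, and the caller always gets a copy — can be sketched without torch (strings stand in for dtypes; the function name is illustrative):

```python
# Illustrative stand-in for the dtype logic in _sum_prod_no_axis.
SMALL_INTS = {"uint8", "int8", "int16", "int32"}

def no_axis_result_dtype(dtype):
    # Small integer inputs upcast to int64, matching the Array API
    # requirement for sum/prod; all other dtypes are preserved.
    return "int64" if dtype in SMALL_INTS else dtype
```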
@ev-br (Member) commented, Apr 15, 2025:


Note to self: this looks scary, but is in fact just a refactoring. Previously this stanza was duplicated in sum and prod.

Returning a copy looks reasonable, too.

@ev-br ev-br merged commit 9194c5c into data-apis:main Apr 15, 2025
40 checks passed
@ev-br ev-br added this to the 1.12 milestone Apr 15, 2025
@crusaderky crusaderky deleted the torch_promotion_table branch April 15, 2025 12:45