Using default_dtypes instead of hard-coding dtypes. #666

Merged (17 commits) on Jan 27, 2025

Contributor

@alxmrs alxmrs commented Jan 13, 2025

Fixes #658.

@alxmrs alxmrs changed the title from "First pass for using default-dtypes instead of hard-coding." to "Using default_dtypes instead of hardcoding dtypes." on Jan 13, 2025
@alxmrs alxmrs changed the title from "Using default_dtypes instead of hardcoding dtypes." to "Using default_dtypes instead of hard-coding dtypes." on Jan 13, 2025
Member

@tomwhite tomwhite left a comment

Thanks for working on this @alxmrs! Added a few comments.

cubed/array_api/creation_functions.py (outdated, resolved)
cubed/array_api/inspection.py (outdated, resolved)
cubed/array_api/creation_functions.py (resolved)
@alxmrs
Contributor Author

alxmrs commented Jan 14, 2025

Thanks for the early feedback, Tom. I'll take a look later today.

@tomwhite
Member

This is looking good!

BTW I just noticed that we don't enforce linting in CI (I opened #668), so you might want to install pre-commit in your dev env and run pre-commit run --all-files.

@alxmrs
Contributor Author

alxmrs commented Jan 15, 2025

Thanks for pointing out how to lint.

As a next step for this PR, I was going to add default_dtype logic to other places besides creation functions. Does that make sense? Would you recommend against this?

@tomwhite
Member

> As a next step for this PR, I was going to add default_dtype logic to other places besides creation functions.

That would be great!
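For readers skimming the thread, the idea behind the PR title can be sketched roughly as follows. This is not the code from the PR: `resolve_creation_dtype` and the two defaults dictionaries are hypothetical, and dtypes are plain strings to keep the sketch dependency-free. The dictionary keys follow the array API standard's default_dtypes() inspection function ("real floating", "complex floating", "integral", "indexing").

```python
from typing import Dict, Optional

# Hypothetical defaults for a NumPy-like backend (double precision).
NUMPY_LIKE_DEFAULTS: Dict[str, str] = {
    "real floating": "float64",
    "complex floating": "complex128",
    "integral": "int64",
    "indexing": "int64",
}

# Hypothetical defaults for a backend that prefers single precision.
SINGLE_PRECISION_DEFAULTS: Dict[str, str] = {
    "real floating": "float32",
    "complex floating": "complex64",
    "integral": "int32",
    "indexing": "int32",
}


def resolve_creation_dtype(dtype: Optional[str], default_dtypes: Dict[str, str]) -> str:
    """Pick the dtype for a creation function such as zeros or ones.

    An explicit dtype always wins; otherwise the backend's default real
    floating dtype is used instead of a hard-coded "float64".
    """
    if dtype is None:
        return default_dtypes["real floating"]
    return dtype


# The same call yields different dtypes depending on the backend defaults.
print(resolve_creation_dtype(None, NUMPY_LIKE_DEFAULTS))
print(resolve_creation_dtype(None, SINGLE_PRECISION_DEFAULTS))
print(resolve_creation_dtype("int16", SINGLE_PRECISION_DEFAULTS))
```

The point of the indirection is that the function body never names a concrete dtype, so a backend with different defaults needs no code changes.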

@alxmrs
Contributor Author

alxmrs commented Jan 15, 2025

I just found this useful dictionary! It's really good to know that these are defined somewhere we can access; I've been looking for exactly this within the array API!

_dtype_categories = {

@alxmrs
Contributor Author

alxmrs commented Jan 15, 2025

Hey @tomwhite, I have a few questions about which types to use in my latest commit. For example, see here:
8ccb343#diff-36cd7306bc9a90e7b718d59616eaf4c80aa06e61e9f1d6262eecc37ad36a4ceeR37

On one hand, it makes sense to me that these should be default_dtypes, since certain types won't work on specific hardware. On the other hand, I don't want to truncate intermediary values.

In my next commit, I'm going to make these default dtypes. If that's not the way we should go, I'll just undo the commit.

@alxmrs
Contributor Author

alxmrs commented Jan 15, 2025

(Looks like the next next step is to add more test coverage 😬.)

@tomwhite
Member

> On one hand, it makes sense to me that these should be default_dtypes, since certain types won't work on specific hardware. On the other hand, I don't want to truncate intermediary values.
>
> In my next commit, I'm going to make these default dtypes. If that's not the way we should go, I'll just undo the commit.

That's a very good point. The choice of dtype for intermediate values is an implementation choice and is not tied down by the spec. I'm inclined not to change this in this PR, and to revisit it when we have a concrete case (say, JAX on GPUs). It would be good to get this merged as a relatively straightforward change, even if some cases need follow-ups.
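The concern about truncating intermediate values can be illustrated with a small standard-library sketch. It is not Cubed code: `to_float32` and `accumulate` are hypothetical helpers that emulate a float32 accumulator by rounding through the struct module, and the input values are chosen only to make the effect visible.

```python
import struct
from typing import Iterable


def to_float32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def accumulate(values: Iterable[float], narrow_intermediates: bool) -> float:
    """Sum values, optionally rounding the running total to float32 each step."""
    total = 0.0
    for v in values:
        total = total + v
        if narrow_intermediates:
            # Emulates keeping the intermediate value in float32.
            total = to_float32(total)
    return total


# 2**24 is the point where float32 can no longer represent every integer.
values = [16777216.0, 1.0, 1.0, 1.0]

# Narrow accumulator: each "+ 1.0" is rounded away, so the total never moves.
narrow = accumulate(values, narrow_intermediates=True)

# Wide accumulator: sum in float64, round to float32 only at the end.
wide = to_float32(accumulate(values, narrow_intermediates=False))

print(narrow)  # stays at 16777216.0
print(wide)    # within float32 rounding of the exact total, 16777219.0
```

Both results end up as float32 values, but only the second one is close to the exact total, which is why the dtype of intermediates matters independently of the output dtype.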

Member

@tomwhite tomwhite left a comment

Added a few more comments. As I said in another comment, I would leave the changes to statistical functions for later (pending further investigation of types for intermediate values) so they don't hold up the rest.

There are a couple of index dtypes in searchsorted and arg_reduction that would be easy to do though.

cubed/array_api/creation_functions.py (outdated, resolved)
cubed/array_api/statistical_functions.py (outdated, resolved)
cubed/array_api/creation_functions.py (outdated, resolved)
@alxmrs
Contributor Author

alxmrs commented Jan 21, 2025

Hey @tomwhite -- I think I may be getting ahead of myself here. I was just adding default_dtypes to nan_functions.py now, but it occurs to me: these defaults seem to promote the precision, not reduce it. Is this intentional? If so, I think that's the last place where default dtypes may be needed.

For now, I'll add a commit with default_dtypes with the intention of removing it should it no longer be necessary.

Member

@tomwhite tomwhite left a comment

Getting closer! I added a few more comments.

cubed/array_api/dtypes.py (resolved)
cubed/array_api/dtypes.py (outdated, resolved)
elif x.dtype == _complex_floating_dtypes:
    dtype = dtypes["complex floating"]
elif x.dtype == _real_floating_dtypes:
    dtype = dtypes["real floating"]
Member

We should remove the two cases for complex and real floating, since sum and prod should not upcast in this case (see https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html#array_api.sum).

Also, nansum should be the same as sum. It's not in the spec, but NumPy behaves the same (version 2.1.1):

>>> np.nansum(np.array([1.0, 2.0], dtype=np.float32))
np.float32(3.0)
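A rough sketch of the rule described in this comment, under stated assumptions: `sum_result_dtype` is a hypothetical helper (not Cubed's implementation), dtypes are plain strings, and the integer handling is a simplified reading of the array API spec for sum when dtype is None. Floating-point inputs keep their dtype, while integer inputs narrower than the default integer dtype are widened.

```python
from typing import Dict

# Bit widths of the integer dtypes named in the array API standard.
SIGNED_BITS = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
UNSIGNED_BITS = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}


def sum_result_dtype(x_dtype: str, default_dtypes: Dict[str, str]) -> str:
    """Return the output dtype of sum(x) when no dtype argument is given."""
    default_int = default_dtypes["integral"]
    default_bits = SIGNED_BITS[default_int]

    if x_dtype in SIGNED_BITS:
        # Narrow signed integers are widened to the default integer dtype.
        return default_int if SIGNED_BITS[x_dtype] < default_bits else x_dtype
    if x_dtype in UNSIGNED_BITS:
        # Narrow unsigned integers get the unsigned dtype of the same width.
        if UNSIGNED_BITS[x_dtype] < default_bits:
            return f"uint{default_bits}"
        return x_dtype
    # Real and complex floating dtypes are returned unchanged (no upcast).
    return x_dtype


defaults = {"real floating": "float64", "integral": "int64"}
print(sum_result_dtype("float32", defaults))  # float32
print(sum_result_dtype("int16", defaults))    # int64
print(sum_result_dtype("uint8", defaults))    # uint64
```

Under this reading, the float32 case matches the np.nansum output shown above: the input dtype is preserved rather than promoted to the default real floating dtype.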

cubed/array_api/dtypes.py (outdated, resolved)
alxmrs and others added 2 commits January 24, 2025 15:01
Cross applied to nan_functions.py.
Co-authored-by: Tom White <tom.e.white@gmail.com>
@alxmrs alxmrs marked this pull request as ready for review January 25, 2025 00:00
Member

@tomwhite tomwhite left a comment

Looks great - thanks for all your work on this @alxmrs!

@tomwhite tomwhite merged commit ce8f61b into cubed-dev:main Jan 27, 2025
16 checks passed
Development

Successfully merging this pull request may close these issues.

Use default_dtypes to determine the output data type in array API implementations
2 participants