BUG: numpy function take strips units with scalar index input #549

Closed
kyleaoman opened this issue Dec 19, 2024 · 3 comments · Fixed by #551
Labels
bug Something isn't working

Comments

@kyleaoman
Contributor

  • unyt version: 3.0.3
  • Python version: 3.12
  • Operating System: Linux (Fedora 40)

Description

The NumPy function np.take (and the unyt_array.take method) strips units when a scalar index, rather than an array-like, is passed.

What I Did

>>> import numpy as np
>>> import unyt as u
>>> a = [1, 2, 3] * u.cm
>>> a.take([0])  # OK
unyt_array([1], 'cm')
>>> np.take(a, [0])  # OK
unyt_array([1], 'cm')
>>> a.take(0)  # not OK: units stripped
np.int64(1)
>>> np.take(a, 0)  # not OK: units stripped
np.int64(1)
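The behaviour can probably be reduced to plain NumPy. The sketch below uses a hypothetical `Tagged` subclass (not part of unyt or NumPy) that, like unyt_array in the affected version, does not override `take`. It suggests that the units are lost because `ndarray.take` returns a bare NumPy scalar for a scalar index, and a scalar has no place to store them. This is a reading of the repro, not something confirmed in the thread.

```python
import numpy as np


class Tagged(np.ndarray):
    """Hypothetical unit-tagged ndarray subclass that does NOT override take()."""

    def __array_finalize__(self, obj):
        # Copy the tag from whatever array this one was derived from.
        self.units = getattr(obj, "units", None)


a = np.asarray([1, 2, 3]).view(Tagged)
a.units = "cm"

vec = a.take([0])  # array-like index: NumPy builds a new Tagged array
sca = a.take(0)    # scalar index: NumPy returns a bare scalar, not an array

print(type(vec).__name__, vec.units)
print(type(sca).__name__, hasattr(sca, "units"))
```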
@neutrinoceros
Member

Ah, this is a case where I just wrote the simplest test case possible (test_take), saw that it worked, and went to the next function. So, we actually don't have an implementation for np.take at the moment.
A quick test reveals that implementing np.take won't automatically make unyt_array.take correct either, so I think we also need to override np.ndarray.take here. We could probably implement one using the other, though. Do you want to give it a go?
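A minimal sketch of the "implement one using the other" idea, assuming a hypothetical toy class `Quantity` with a string `units` attribute in place of unyt's real classes; the `HANDLED_FUNCTIONS`/`implements` registry is likewise illustrative and is not how #551 necessarily did it. `np.take` is intercepted via the NEP 18 `__array_function__` protocol, and `take` the method simply delegates to `np.take`. The key step is wrapping the raw result with `np.asarray`, which turns a bare NumPy scalar into a 0-d array that can carry units.

```python
import numpy as np

# Registry of NumPy functions the toy class overrides (hypothetical helper,
# following the NEP 18 __array_function__ pattern).
HANDLED_FUNCTIONS = {}


def implements(np_func):
    """Register a unit-aware replacement for a NumPy function."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_func] = func
        return func
    return decorator


class Quantity(np.ndarray):
    """Hypothetical stand-in for unyt_array: an ndarray with a `units` tag."""

    def __new__(cls, values, units):
        obj = np.asarray(values).view(cls)
        obj.units = units
        return obj

    def __array_finalize__(self, obj):
        # Propagate units to views and to arrays derived from `obj`.
        self.units = getattr(obj, "units", None)

    def __array_function__(self, func, types, args, kwargs):
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return super().__array_function__(func, types, args, kwargs)

    def take(self, indices, axis=None, mode="raise"):
        # Override ndarray.take by delegating to np.take, so the method and
        # the function share one code path. `out` is omitted for brevity.
        return np.take(self, indices, axis=axis, mode=mode)


@implements(np.take)
def _take(a, indices, axis=None, mode="raise"):
    # Select on a plain ndarray so np.take does not dispatch back here.
    raw = np.take(np.asarray(a), indices, axis=axis, mode=mode)
    # A scalar index yields a bare NumPy scalar; np.asarray turns it into a
    # 0-d array so the units can be re-attached instead of being dropped.
    result = np.asarray(raw).view(Quantity)
    result.units = getattr(a, "units", None)
    return result


a = Quantity([1, 2, 3], "cm")
print(a.take(0).units, np.take(a, 0).units)  # prints: cm cm
```

Routing the method through the function keeps a single implementation; the reverse direction (np.take calling the method) would work just as well.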

@kyleaoman
Contributor Author

Yes, possibly, but I'm now on Christmas holiday, so maybe not soon :)

@neutrinoceros
Member

No worries. Happy holidays!
