Arithmetic operations accept numpy arrays #102

Open
ev-br opened this issue Nov 27, 2024 · 3 comments · May be fixed by #115

Comments

@ev-br
Contributor

ev-br commented Nov 27, 2024

Supposedly, mixing array-api-strict arrays with other array types should not be allowed.

Or all of them should be allowed, but then we'd need to specify something like __array_priority__, and that opens quite a Pandora's box, so I guess not?

In [5]: import numpy as np

In [6]: import array_api_strict as xp

In [7]: xp.arange(5, dtype=xp.int8) + np.arange(5, dtype=np.complex64)
Out[7]: array([0.+0.j, 2.+0.j, 4.+0.j, 6.+0.j, 8.+0.j], dtype=complex64)           # xp + np -> np !

In [8]: import torch

In [10]: xp.arange(5, dtype=xp.int8) + torch.arange(5)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 xp.arange(5, dtype=xp.int8) + torch.arange(5)

TypeError: unsupported operand type(s) for +: 'Array' and 'Tensor'

In [11]: import jax.numpy as jnp

In [12]: xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)

TypeError: unsupported operand type(s) for +: 'Array' and 'jaxlib.xla_extension.ArrayImpl'
@ev-br
Contributor Author

ev-br commented Nov 27, 2024

The offender is https://github.com/data-apis/array-api-strict/blob/main/array_api_strict/_array_object.py#L189 :

    def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
        ...
        if isinstance(other, (int, complex, float, bool)):
            ...
        elif isinstance(other, Array):
            ....
        else:
            return NotImplemented

and then, after __add__ returns NotImplemented, control flow ends up in __array__, which happily calls np.asarray(self._array).
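
The fallback path described here is ordinary Python operator dispatch: when the left operand's `__add__` returns `NotImplemented`, Python tries the right operand's `__radd__`. Below is a minimal, NumPy-free sketch of that mechanism. `StrictLike`, `GreedyLike` and `_as_list` are hypothetical stand-ins (for `Array`, `ndarray` and `__array__` respectively), not code from either library.

```python
class StrictLike:
    """Hypothetical stand-in for a strict array type."""

    def __init__(self, data):
        self._data = list(data)

    def _as_list(self):
        # Plays the role of __array__: exposes the data to any caller.
        return list(self._data)

    def __add__(self, other):
        if isinstance(other, StrictLike):
            return StrictLike(a + b for a, b in zip(self._data, other._data))
        # Unknown operand: defer to Python's reflected-operator lookup.
        return NotImplemented


class GreedyLike:
    """Hypothetical stand-in for a type that coerces anything convertible."""

    def __init__(self, data):
        self._data = list(data)

    def __radd__(self, other):
        # Reached only after other.__add__(self) returned NotImplemented.
        converted = other._as_list() if hasattr(other, "_as_list") else list(other)
        return GreedyLike(a + b for a, b in zip(converted, self._data))


result = StrictLike([0, 1, 2]) + GreedyLike([0, 1, 2])
print(type(result).__name__, result._data)  # GreedyLike [0, 2, 4]
```

The strict operand is on the left, yet the result has the other type, which mirrors the `xp + np -> np` outcome in the report.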

ISTM it's best to add an explicit type check to __op__(self, other) and explicitly whitelist arrays or scalars, so that we do not depend on #67.
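
A minimal sketch of such an explicit check, assuming a hypothetical `StrictLike` class and a hypothetical `_check_operand` helper (neither is array-api-strict code). The operator raises `TypeError` itself instead of returning `NotImplemented`, so Python never consults the other operand's reflected method.

```python
class StrictLike:
    """Hypothetical strict array type with whitelisted operands."""

    _SCALAR_TYPES = (bool, int, float, complex)

    def __init__(self, data):
        self._data = list(data)

    def _check_operand(self, other, op):
        # Accept only our own type or Python scalars; reject everything else.
        if isinstance(other, (StrictLike, *self._SCALAR_TYPES)):
            return
        raise TypeError(
            f"unsupported operand type for {op}: {type(other).__name__!r}"
        )

    def __add__(self, other):
        self._check_operand(other, "__add__")
        if isinstance(other, StrictLike):
            return StrictLike(a + b for a, b in zip(self._data, other._data))
        return StrictLike(a + other for a in self._data)


class Foreign:
    """Hypothetical foreign type that would coerce if it were asked."""

    def __radd__(self, other):
        return "coerced"


print((StrictLike([1, 2]) + 1)._data)  # [2, 3]
try:
    StrictLike([1, 2]) + Foreign()
    rejected = False
except TypeError as exc:
    rejected = True
    print("rejected:", exc)
```

Raising early means `Foreign.__radd__` is never called, so the check does not depend on whether the other type can convert the strict array.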

EDIT: Alternatively, can probably attach the _allow_array flag to the Array object and force it to False in binops.

@asmeurer
Member

Yes, ideally we would remove __array__, which would presumably fix this. But there were issues with that, which are described in the issue (the principal issue being that array-api-strict doesn't support the buffer protocol, as discussed in a recent consortium meeting).

@ev-br
Contributor Author

ev-br commented Nov 27, 2024

Yes, it would just disappear if we could remove __array__. But realistically we cannot, not until from_dlpack matures anyway.

Re: buffer protocol, PEP 688 was mentioned in the consortium meeting, and indeed:

>>> class X:
...    def __init__(self, a):
...      self._a = a
...    def __buffer__(self, flags):
...       print('__buffer__')
...       return memoryview(self._a)
...    def __release_buffer__(self, buffer):
...       print('__release__')
...       # what now? do we `del self._a` here?
... 
>>> import numpy as np
>>> x = X(np.arange(5))
>>> np.asarray(x)
__buffer__
array([0, 1, 2, 3, 4])

The main problem, as also mentioned offline, is that it's new in Python 3.12, meaning downstream projects like SciPy will only be able to use it in a couple of years at best (cf. NEP 29, https://numpy.org/neps/nep-0029-deprecation_policy.html).

So it seems that our options today are either

  • do nothing and wait for dlpack or __buffer__ support, whichever comes first;
  • block what we can, as in "BUG: reject ndarrays in binary operators" (#103): block xp.array op np.array and accept that we can't do a thing about np.array op xp.array.
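
The asymmetry in the second option follows from the order in which Python tries operator methods. The left operand's `__add__` runs first, so a check inside the strict class never executes when the foreign array is on the left and handles the operation itself. Here is a toy sketch with hypothetical `StrictLike` and `ForeignLike` classes (not NumPy or array-api-strict code), where `__iter__` stands in for `__array__`:

```python
class StrictLike:
    """Hypothetical strict type that rejects foreign operands."""

    def __init__(self, data):
        self._data = list(data)

    def __iter__(self):
        # Stands in for __array__: a way for foreign code to read the data.
        return iter(self._data)

    def __add__(self, other):
        if not isinstance(other, StrictLike):
            raise TypeError("StrictLike only adds to StrictLike")
        return StrictLike(a + b for a, b in zip(self._data, other._data))

    def __radd__(self, other):
        # Only reached if other.__add__(self) returned NotImplemented.
        raise TypeError("StrictLike only adds to StrictLike")


class ForeignLike:
    """Hypothetical foreign type that coerces any iterable operand."""

    def __init__(self, data):
        self._data = list(data)

    def __add__(self, other):
        # Never returns NotImplemented, so StrictLike.__radd__ is not tried.
        return ForeignLike(a + b for a, b in zip(self._data, other))


strict, foreign = StrictLike([1, 2]), ForeignLike([10, 20])

try:
    strict + foreign
    blocked = False
except TypeError:
    blocked = True

leaked = foreign + strict  # handled entirely by ForeignLike.__add__
print(blocked, type(leaked).__name__, leaked._data)  # True ForeignLike [11, 22]
```

In this model the strict type could only stop the second case by refusing to expose its data at all, which corresponds to removing `__array__`.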
