From 6d780a8ab542af6e09477416a2843f921878c5a4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 15 Oct 2024 11:28:13 -0600 Subject: [PATCH] Fix issue with repeat() NumPy does not allow repeats to be uint64 because it refuses to downcast it. Technically it worked before because we implement __array__ and repeat does manually cast in that case. I'm not really sure we should be supporting __array__ actually. --- array_api_strict/_manipulation_functions.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 7652028..702d259 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -2,8 +2,8 @@ from ._array_object import Array from ._creation_functions import asarray -from ._data_type_functions import result_type -from ._dtypes import _integer_dtypes +from ._data_type_functions import astype, result_type +from ._dtypes import _integer_dtypes, int64, uint64 from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -94,7 +94,13 @@ def repeat( else: raise TypeError("repeats must be an int or array") - return Array._new(np.repeat(x._array, repeats, axis=axis)) + if repeats.dtype == uint64: + # NumPy does not allow uint64 because can't be cast down to x.dtype + # with 'safe' casting. However, repeats values larger than 2**63 are + # infeasable, and even if they are present by mistake, this will + # lead to underflow and an error. + repeats = astype(repeats, int64) + return Array._new(np.repeat(x._array, repeats._array, axis=axis)) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array,