-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
topk and argtopk #10086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
topk and argtopk #10086
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -27,7 +27,12 @@ | |||||
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils | ||||||
from xarray.core.array_api_compat import get_array_namespace | ||||||
from xarray.core.options import OPTIONS | ||||||
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available | ||||||
from xarray.core.utils import ( | ||||||
is_duck_array, | ||||||
is_duck_dask_array, | ||||||
module_available, | ||||||
to_0d_object_array, | ||||||
) | ||||||
from xarray.namedarray.parallelcompat import get_chunked_array_type | ||||||
from xarray.namedarray.pycompat import array_type, is_chunked_array | ||||||
|
||||||
|
@@ -229,7 +234,7 @@ | |||||
xp = get_array_namespace(data) | ||||||
if xp == np: | ||||||
# numpy currently doesn't have a astype: | ||||||
return data.astype(dtype, **kwargs) | ||||||
Check warning on line 237 in xarray/core/duck_array_ops.py
|
||||||
return xp.astype(data, dtype, **kwargs) | ||||||
return data.astype(dtype, **kwargs) | ||||||
|
||||||
|
@@ -875,3 +880,74 @@ | |||||
|
||||||
def chunked_nanlast(darray, axis): | ||||||
return _chunked_first_or_last(darray, axis, op=nputils.nanlast) | ||||||
|
||||||
|
||||||
def argtopk(values, k, axis=None, skipna=None): | ||||||
if is_chunked_array(values): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
func = dask_array_ops.argtopk | ||||||
else: | ||||||
func = nputils.argtopk | ||||||
|
||||||
# Borrowed from nanops | ||||||
xp = get_array_namespace(values) | ||||||
if skipna or ( | ||||||
skipna is None | ||||||
and ( | ||||||
dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp) | ||||||
or dtypes.is_object(values.dtype) | ||||||
) | ||||||
): | ||||||
valid_count = count(values, axis=axis) | ||||||
|
||||||
if k < 0: | ||||||
fill_value = dtypes.get_pos_infinity(values.dtype) | ||||||
else: | ||||||
fill_value = dtypes.get_neg_infinity(values.dtype) | ||||||
|
||||||
filled_values = fillna(values, fill_value) | ||||||
else: | ||||||
return func(values, k=k, axis=axis) | ||||||
|
||||||
data = func(filled_values, k=k, axis=axis) | ||||||
|
||||||
# TODO This will evaluate dask arrays and might be costly. | ||||||
if array_any(valid_count == 0): | ||||||
raise ValueError("All-NaN slice encountered") | ||||||
return data | ||||||
|
||||||
|
||||||
def topk(values, k, axis=None, skipna=None): | ||||||
if is_chunked_array(values): | ||||||
func = dask_array_ops.topk | ||||||
else: | ||||||
func = nputils.topk | ||||||
|
||||||
# Borrowed from nanops | ||||||
xp = get_array_namespace(values) | ||||||
if skipna or ( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the way to do this is to use the fact that |
||||||
skipna is None | ||||||
and ( | ||||||
dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp) | ||||||
or dtypes.is_object(values.dtype) | ||||||
) | ||||||
): | ||||||
valid_count = count(values, axis=axis) | ||||||
|
||||||
if k < 0: | ||||||
fill_value = dtypes.get_pos_infinity(values.dtype) | ||||||
else: | ||||||
fill_value = dtypes.get_neg_infinity(values.dtype) | ||||||
|
||||||
filled_values = fillna(values, fill_value) | ||||||
else: | ||||||
return func(values, k=k, axis=axis) | ||||||
|
||||||
data = func(filled_values, k=k, axis=axis) | ||||||
|
||||||
if not hasattr(data, "dtype"): # scalar case | ||||||
data = fill_value if valid_count == 0 else data | ||||||
# we've computed a single min, max value of type object. | ||||||
# don't let np.array turn a tuple back into an array | ||||||
return to_0d_object_array(data) | ||||||
|
||||||
return where_method(data, valid_count != 0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,6 +302,59 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): | |
return coeffs, residuals | ||
|
||
|
||
def topk(values, k: int, axis: int): | ||
"""Extract the k largest elements from a on the given axis. | ||
If k is negative, extract the -k smallest elements instead. | ||
The returned elements are sorted. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would not sort. The user can do that if they really need to. |
||
""" | ||
if axis < 0: | ||
axis = values.ndim + axis | ||
|
||
if abs(k) >= values.shape[axis]: | ||
b = np.sort(values, axis=axis) | ||
else: | ||
a = np.partition(values, -k, axis=axis) | ||
k_slice = slice(-k, None) if k > 0 else slice(-k) | ||
b = a[tuple(k_slice if i == axis else slice(None) for i in range(values.ndim))] | ||
b.sort(axis=axis) | ||
if k < 0: | ||
return b | ||
return b[ | ||
tuple( | ||
slice(None, None, -1) if i == axis else slice(None) | ||
for i in range(values.ndim) | ||
) | ||
] | ||
|
||
|
||
def argtopk(values, k: int, axis: int): | ||
"""Extract the indices of the k largest elements from a on the given axis. | ||
If k is negative, extract the indices of the -k smallest elements instead. | ||
The returned elements are argsorted. | ||
""" | ||
if axis < 0: | ||
axis = values.ndim + axis | ||
|
||
if abs(k) >= values.shape[axis]: | ||
idx3 = np.argsort(values, axis=axis) | ||
else: | ||
idx = np.argpartition(values, -k, axis=axis) | ||
k_slice = slice(-k, None) if k > 0 else slice(-k) | ||
idx = idx[ | ||
tuple(k_slice if i == axis else slice(None) for i in range(values.ndim)) | ||
] | ||
a = np.take_along_axis(values, idx, axis) | ||
idx2 = np.argsort(a, axis=axis) | ||
idx3 = np.take_along_axis(idx, idx2, axis) | ||
if k < 0: | ||
return idx3 | ||
return idx3[ | ||
tuple( | ||
slice(None, None, -1) if i == axis else slice(None) for i in range(idx.ndim) | ||
) | ||
] | ||
|
||
|
||
nanmin = _create_method("nanmin") | ||
nanmax = _create_method("nanmax") | ||
nanmean = _create_method("nanmean") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2511,6 +2511,146 @@ def argmax( | |
""" | ||
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) | ||
|
||
def _topk_stack( | ||
self, | ||
topk_funcname: str, | ||
dim: Dims, | ||
) -> Variable: | ||
# Get a name for the new dimension that does not conflict with any existing | ||
# dimension | ||
newdimname = f"_unravel_{topk_funcname}_dim_0" | ||
count = 1 | ||
while newdimname in self.dims: | ||
newdimname = f"_unravel_{topk_funcname}_dim_{count}" | ||
count += 1 | ||
return self.stack({newdimname: dim}) | ||
|
||
def _topk_helper( | ||
self, | ||
topk_funcname: str, | ||
k: int, | ||
dim: str, | ||
dtype: Any, | ||
keep_attrs: bool | None = None, | ||
skipna: bool | None = None, | ||
) -> Variable: | ||
from xarray.core.computation import apply_ufunc | ||
|
||
topk_func = getattr(duck_array_ops, topk_funcname) | ||
# apply_ufunc moves the dimension to the back. | ||
kwargs = {"k": k, "axis": -1, "skipna": skipna} | ||
|
||
result = apply_ufunc( | ||
topk_func, | ||
self, | ||
input_core_dims=[[dim]], | ||
exclude_dims={dim}, | ||
output_core_dims=[[topk_funcname]], | ||
output_dtypes=[dtype], | ||
dask_gufunc_kwargs=dict(output_sizes={topk_funcname: k}), | ||
dask="allowed", | ||
kwargs=kwargs, | ||
) | ||
|
||
keep_attrs_ = ( | ||
_get_keep_attrs(default=False) if keep_attrs is None else keep_attrs | ||
) | ||
|
||
if keep_attrs_: | ||
result.attrs = self._attrs | ||
return result | ||
|
||
def topk( | ||
self, | ||
k: int, | ||
dim: Dims = None, | ||
keep_attrs: bool | None = None, | ||
skipna: bool | None = None, | ||
) -> Variable | dict[Hashable, Variable]: | ||
""" | ||
TODO docstring | ||
""" | ||
# topk accepts only an integer axis like argmin or argmax, | ||
# not tuples, so we need to stack multiple dimensions. | ||
if dim is ... or dim is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# Return dimension for 1D data. | ||
if self.ndim == 1: | ||
dim = self.dims[0] | ||
else: | ||
dim = self.dims | ||
|
||
if isinstance(dim, str): | ||
stacked = self | ||
else: | ||
stacked = self._topk_stack("topk", dim) | ||
dim = stacked.dims[-1] | ||
|
||
result = stacked._topk_helper( | ||
"topk", k=k, dim=dim, dtype=self.dtype, keep_attrs=keep_attrs, skipna=skipna | ||
) | ||
return result | ||
|
||
def argtopk( | ||
self, | ||
k: int, | ||
dim: Dims = None, | ||
keep_attrs: bool | None = None, | ||
skipna: bool | None = None, | ||
) -> Variable | dict[Hashable, Variable]: | ||
""" | ||
TODO docstring | ||
""" | ||
# argtopk accepts only an integer axis like argmin or argmax, | ||
# not tuples, so we need to stack multiple dimensions. | ||
if dim is ... or dim is None: | ||
# Return dimension for 1D data. | ||
if self.ndim == 1: | ||
dim = self.dims[0] | ||
else: | ||
dim = self.dims | ||
|
||
if isinstance(dim, str): | ||
return self._topk_helper( | ||
"argtopk", | ||
k=k, | ||
dim=dim, | ||
dtype=np.intp, | ||
keep_attrs=keep_attrs, | ||
skipna=skipna, | ||
) | ||
|
||
stacked = self._topk_stack("topk", dim) | ||
newdimname = stacked.dims[-1] | ||
|
||
result_flat_indices = stacked._topk_helper( | ||
"argtopk", | ||
k=k, | ||
dim=newdimname, | ||
dtype=np.intp, | ||
keep_attrs=keep_attrs, | ||
skipna=skipna, | ||
) | ||
|
||
reduce_shape = tuple(self.sizes[d] for d in dim) | ||
|
||
result_unravelled_indices = duck_array_ops.unravel_index( | ||
result_flat_indices.data, reduce_shape | ||
) | ||
|
||
result_dims = [d for d in stacked.dims if d != newdimname] + ["argtopk"] | ||
result = { | ||
d: Variable(dims=result_dims, data=i) | ||
for d, i in zip(dim, result_unravelled_indices, strict=True) | ||
} | ||
|
||
if keep_attrs is None: | ||
keep_attrs = _get_keep_attrs(default=False) | ||
if keep_attrs: | ||
for v in result.values(): | ||
v.attrs = self.attrs | ||
|
||
return result | ||
|
||
def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: | ||
""" | ||
Use sparse-array as backend. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
generate_aggregations.py
would make sense here so that we add it everywhere with the same docstring. That would get us groupby support for example, and I can eventually plug in flox when https://github.com/xarray-contrib/flox/pull/374/files is ready