Skip to content

Commit

Permalink
Revert "Intermediate values use default dtypes in statistical_functio…
Browse files Browse the repository at this point in the history
…ns.py."

This reverts commit 6a2f60f.
  • Loading branch information
alxmrs committed Jan 21, 2025
1 parent a19e79f commit 11c3441
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def max(x, /, *, axis=None, keepdims=False, split_every=None):
)


def mean(x, /, *, axis=None, keepdims=False, split_every=None, device=None):
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses a Zarr group of two arrays to store a
# pair of fields needed to keep per-chunk counts and totals for computing
# the mean.
dtype = x.dtype
dtypes = __array_namespace_info__().default_dtypes(device=device)
intermediate_dtype = [("n", dtypes['integral']), ("total", dtypes['real floating'])]
#TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
x,
Expand Down Expand Up @@ -161,15 +161,14 @@ def var(
correction=0.0,
keepdims=False,
split_every=None,
device=None,
):
# This implementation follows https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
dtype = x.dtype
dtypes = __array_namespace_info__().default_dtypes(device=device)
intermediate_dtype = [("n", dtypes['integral']), ("mu", dtypes['real floating']), ("M2", dtypes['real floating'])]
#TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
return reduction(
x,
Expand Down

0 comments on commit 11c3441

Please # to comment.