From a21a5fae5e7f488fd504070e5ee157bc04643c9b Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 15 Jan 2025 11:53:13 +0000 Subject: [PATCH] Incorrect dtypes in `map_selection` (#669) --- cubed/core/ops.py | 4 ++-- cubed/tests/test_overlap.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 372025596..d9154183d 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -816,9 +816,9 @@ def key_function(out_key): key_function, x, shapes=[shape], - dtypes=[x.dtype], + dtypes=[dtype], chunkss=[chunks], - extra_func_kwargs=dict(func=func, dtype=dtype), + extra_func_kwargs=dict(func=func, dtype=x.dtype), num_input_blocks=num_input_blocks, iterable_input_blocks=iterable_input_blocks, selection_function=selection_function, diff --git a/cubed/tests/test_overlap.py b/cubed/tests/test_overlap.py index 5c29e3a76..a388a3caa 100644 --- a/cubed/tests/test_overlap.py +++ b/cubed/tests/test_overlap.py @@ -39,6 +39,24 @@ def test_map_overlap_1d_single_chunk(): assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 4, 5, 0])) +def test_map_overlap_1d_change_dtype(): + x = np.arange(6) + a = xp.asarray(x, chunks=(3,)) + + b = cubed.map_overlap( + lambda x: x.astype(np.float64), + a, + dtype=np.float64, + chunks=((5, 5),), + depth=1, + boundary=0, + trim=False, + ) + + assert b.dtype == np.float64 + assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0])) + + def test_map_overlap_2d(): x = np.arange(36).reshape((6, 6)) a = xp.asarray(x, chunks=(3, 3))