diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 37202559..d9154183 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -816,9 +816,9 @@ def key_function(out_key): key_function, x, shapes=[shape], - dtypes=[x.dtype], + dtypes=[dtype], chunkss=[chunks], - extra_func_kwargs=dict(func=func, dtype=dtype), + extra_func_kwargs=dict(func=func, dtype=x.dtype), num_input_blocks=num_input_blocks, iterable_input_blocks=iterable_input_blocks, selection_function=selection_function, diff --git a/cubed/tests/test_overlap.py b/cubed/tests/test_overlap.py index 5c29e3a7..a388a3ca 100644 --- a/cubed/tests/test_overlap.py +++ b/cubed/tests/test_overlap.py @@ -39,6 +39,24 @@ def test_map_overlap_1d_single_chunk(): assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 4, 5, 0])) +def test_map_overlap_1d_change_dtype(): + x = np.arange(6) + a = xp.asarray(x, chunks=(3,)) + + b = cubed.map_overlap( + lambda x: x.astype(np.float64), + a, + dtype=np.float64, + chunks=((5, 5),), + depth=1, + boundary=0, + trim=False, + ) + + assert b.dtype == np.float64 + assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0])) + + def test_map_overlap_2d(): x = np.arange(36).reshape((6, 6)) a = xp.asarray(x, chunks=(3, 3))