Skip to content

Commit

Permalink
Incorrect dtypes in map_selection (#669)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 15, 2025
1 parent 6dc00e1 commit a21a5fa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,9 @@ def key_function(out_key):
key_function,
x,
shapes=[shape],
dtypes=[x.dtype],
dtypes=[dtype],
chunkss=[chunks],
extra_func_kwargs=dict(func=func, dtype=dtype),
extra_func_kwargs=dict(func=func, dtype=x.dtype),
num_input_blocks=num_input_blocks,
iterable_input_blocks=iterable_input_blocks,
selection_function=selection_function,
Expand Down
18 changes: 18 additions & 0 deletions cubed/tests/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ def test_map_overlap_1d_single_chunk():
assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 4, 5, 0]))


def test_map_overlap_1d_change_dtype():
x = np.arange(6)
a = xp.asarray(x, chunks=(3,))

b = cubed.map_overlap(
lambda x: x.astype(np.float64),
a,
dtype=np.float64,
chunks=((5, 5),),
depth=1,
boundary=0,
trim=False,
)

assert b.dtype == np.float64
assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0]))


def test_map_overlap_2d():
x = np.arange(36).reshape((6, 6))
a = xp.asarray(x, chunks=(3, 3))
Expand Down

0 comments on commit a21a5fa

Please # to comment.