Skip to content

Remove functionality of batching submission in client.map #8771

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from typing_extensions import TypeAlias

from packaging.version import parse as parse_version
from tlz import first, groupby, merge, partition_all, valmap
from tlz import first, groupby, merge, valmap

import dask
from dask.base import collections_to_dsk, tokenize
Expand Down Expand Up @@ -2291,34 +2291,9 @@ def map(
)
total_length = sum(len(x) for x in iterables)
if batch_size and batch_size > 1 and total_length > batch_size:
batches = list(
zip(*(partition_all(batch_size, iterable) for iterable in iterables))
)
keys: list[list[Any]] | list[Any]
if isinstance(key, list):
keys = [list(element) for element in partition_all(batch_size, key)]
else:
keys = [key for _ in range(len(batches))]
return sum(
(
self.map(
func,
*batch,
key=key,
workers=workers,
retries=retries,
priority=priority,
allow_other_workers=allow_other_workers,
fifo_timeout=fifo_timeout,
resources=resources,
actor=actor,
actors=actors,
pure=pure,
**kwargs,
)
for key, batch in zip(keys, batches)
),
[],
warnings.warn(
'The argument "batch_size" is ignored and will be removed in a future version.',
DeprecationWarning,
)

key = key or funcname(func)
Expand Down
23 changes: 14 additions & 9 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,20 @@ async def test_map_retries(c, s, a, b):

@gen_cluster(client=True)
async def test_map_batch_size(c, s, a, b):
result = c.map(inc, range(100), batch_size=10)
with pytest.deprecated_call(match="batch_size"):
result = c.map(inc, range(100), batch_size=10)
result = await c.gather(result)
assert result == list(range(1, 101))

result = c.map(add, range(100), range(100), batch_size=10)
with pytest.deprecated_call(match="batch_size"):
result = c.map(add, range(100), range(100), batch_size=10)
result = await c.gather(result)
assert result == list(range(0, 200, 2))

# mismatch shape
result = c.map(add, range(100, 200), range(10), batch_size=2)

with pytest.deprecated_call(match="batch_size"):
result = c.map(add, range(100, 200), range(10), batch_size=2)
result = await c.gather(result)
assert result == list(range(100, 120, 2))

Expand All @@ -261,12 +265,13 @@ async def test_map_batch_size(c, s, a, b):
async def test_custom_key_with_batches(c, s, a, b):
"""Test of <https://github.com/dask/distributed/issues/4588>"""

futs = c.map(
lambda x: x**2,
range(10),
batch_size=5,
key=[str(x) for x in range(10)],
)
with pytest.deprecated_call(match="batch_size"):
futs = c.map(
lambda x: x**2,
range(10),
batch_size=5,
key=[str(x) for x in range(10)],
)
assert len(futs) == 10
await wait(futs)

Expand Down
Loading