From fc39280904bcbce53a1583e71ecfd66db73b999f Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 16 Jul 2024 14:36:12 +0200 Subject: [PATCH 1/2] remove functionality of batching submission in client.map --- distributed/client.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index b5c6eb3789e..f3f8230bb0e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -41,7 +41,7 @@ from typing_extensions import TypeAlias from packaging.version import parse as parse_version -from tlz import first, groupby, merge, partition_all, valmap +from tlz import first, groupby, merge, valmap import dask from dask.base import collections_to_dsk, tokenize @@ -2291,34 +2291,9 @@ def map( ) total_length = sum(len(x) for x in iterables) if batch_size and batch_size > 1 and total_length > batch_size: - batches = list( - zip(*(partition_all(batch_size, iterable) for iterable in iterables)) - ) - keys: list[list[Any]] | list[Any] - if isinstance(key, list): - keys = [list(element) for element in partition_all(batch_size, key)] - else: - keys = [key for _ in range(len(batches))] - return sum( - ( - self.map( - func, - *batch, - key=key, - workers=workers, - retries=retries, - priority=priority, - allow_other_workers=allow_other_workers, - fifo_timeout=fifo_timeout, - resources=resources, - actor=actor, - actors=actors, - pure=pure, - **kwargs, - ) - for key, batch in zip(keys, batches) - ), - [], + warnings.warn( + 'The argument "batch_size" is ignored and will be removed in a future version.', + DeprecationWarning, ) key = key or funcname(func) From a04a25ffbe6367422ba597e26ba59cd552d7df19 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 16 Jul 2024 14:39:28 +0200 Subject: [PATCH 2/2] assert on deprecation warning --- distributed/tests/test_client.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c3c40519299..efdb3dab12e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -243,16 +243,20 @@ async def test_map_retries(c, s, a, b): @gen_cluster(client=True) async def test_map_batch_size(c, s, a, b): - result = c.map(inc, range(100), batch_size=10) + with pytest.deprecated_call(match="batch_size"): + result = c.map(inc, range(100), batch_size=10) result = await c.gather(result) assert result == list(range(1, 101)) - result = c.map(add, range(100), range(100), batch_size=10) + with pytest.deprecated_call(match="batch_size"): + result = c.map(add, range(100), range(100), batch_size=10) result = await c.gather(result) assert result == list(range(0, 200, 2)) # mismatch shape - result = c.map(add, range(100, 200), range(10), batch_size=2) + + with pytest.deprecated_call(match="batch_size"): + result = c.map(add, range(100, 200), range(10), batch_size=2) result = await c.gather(result) assert result == list(range(100, 120, 2)) @@ -261,12 +265,13 @@ async def test_map_batch_size(c, s, a, b): async def test_custom_key_with_batches(c, s, a, b): """Test of """ - futs = c.map( - lambda x: x**2, - range(10), - batch_size=5, - key=[str(x) for x in range(10)], - ) + with pytest.deprecated_call(match="batch_size"): + futs = c.map( + lambda x: x**2, + range(10), + batch_size=5, + key=[str(x) for x in range(10)], + ) assert len(futs) == 10 await wait(futs)