From 75b9b132cfb26d6ebbc3709eecb42688d6131711 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 12 Jun 2020 12:37:09 +0200 Subject: [PATCH 1/3] replace the object array with generator expressions and zip/enumerate --- xarray/core/parallel.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 522c5b36ff5..df36c3dff5f 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -34,11 +34,8 @@ T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) -def to_object_array(iterable): - # using empty_like calls compute - npargs = np.empty((len(iterable),), dtype=np.object) - npargs[:] = iterable - return npargs +def unzip(iterable): + return zip(*iterable) def assert_chunks_compatible(a: Dataset, b: Dataset): @@ -335,23 +332,33 @@ def _wrapper( if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) - npargs = to_object_array([obj] + list(args)) - is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] - is_array = [isinstance(arg, DataArray) for arg in npargs] + all_args = [obj] + list(args) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args] + is_array = [isinstance(arg, DataArray) for arg in all_args] + + # there should be a better way to group this. partition? + xarray_indices, xarray_objs = unzip( + (index, arg) for index, arg in enumerate(all_args) if is_xarray[index] + ) + others = [ + (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index] + ] # all xarray objects must be aligned. This is consistent with apply_ufunc. - aligned = align(*npargs[is_xarray], join="exact") - # assigning to object arrays works better when RHS is object array - # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array - npargs[is_xarray] = to_object_array(aligned) - npargs[is_array] = to_object_array( - [dataarray_to_dataset(da) for da in npargs[is_array]] + aligned = align(*xarray_objs, join="exact") + xarray_objs = tuple( + dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg + for arg in aligned + ) + + _, npargs = unzip( + sorted((list(zip(xarray_indices, xarray_objs)) + others), key=lambda x: x[0]) ) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) input_indexes = dict(npargs[0].indexes) - for arg in npargs[1:][is_xarray[1:]]: + for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) input_indexes.update(arg.indexes) From e94bca79081a8f687e7fef4a44c98b00d240f5f9 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 12 Jun 2020 13:19:02 +0200 Subject: [PATCH 2/3] remove a leftover grouping pair of parentheses --- xarray/core/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index df36c3dff5f..960fc684e96 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -352,7 +352,7 @@ def _wrapper( ) _, npargs = unzip( - sorted((list(zip(xarray_indices, xarray_objs)) + others), key=lambda x: x[0]) + sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) # check that chunk sizes are compatible From 6f12b72b7d204443b38417000ae17f9820390adf Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 12 Jun 2020 13:21:35 +0200 Subject: [PATCH 3/3] reuse is_array instead of comparing again --- xarray/core/parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 960fc684e96..3a77753d0d1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -347,8 +347,8 @@ def _wrapper( # all xarray objects must be aligned. This is consistent with apply_ufunc. aligned = align(*xarray_objs, join="exact") xarray_objs = tuple( - dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg - for arg in aligned + dataarray_to_dataset(arg) if is_da else arg + for is_da, arg in zip(is_array, aligned) ) _, npargs = unzip(