diff --git a/zarr/core.py b/zarr/core.py
index d22a9d79c3..4c288b3521 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -10,6 +10,10 @@
 import numpy as np
 from numcodecs.compat import ensure_bytes
+import tempfile
+import logging
+logger = logging.getLogger(__name__)
+
 
 from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
 from zarr.attrs import Attributes
 from zarr.codecs import AsType, get_codec
@@ -60,8 +64,30 @@
     ensure_ndarray_like,
 )
 
+
+from joblib import Parallel, delayed
+
 __all__ = ["Array"]
 
+# Number of chunk keys required before _chunk_getitems switches to parallel reads
+PARALLEL_THRESHOLD = 12
+# Number of chunks to batch per parallel worker task
+PARALLEL_BATCH_SIZE = 8
+def parallel_io_method(instance, c_key, c_select, out_sel, drop_axes, my_out):
+    try:
+        cdata = instance.chunk_store[c_key]
+        chunk = instance._decode_chunk(cdata)
+        tmp = chunk[c_select]
+        if drop_axes:
+            tmp = np.squeeze(tmp, axis=drop_axes)
+        my_out[out_sel] = tmp
+
+    except Exception:
+        # If the read/parse failed, the file name will be in the exception.
+        # If the key is not present, there is no more info to log.
+        logger.exception("Error reading chunk %s", c_key)
+        my_out[out_sel] = instance._fill_value
+
 
 # noinspection PyUnresolvedReferences
 class Array:
@@ -2163,6 +2189,37 @@ def _chunk_getitems(
             partial_read_decode = False
             values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
             cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
+
+
+        elif "GRIBCodec" in [type(f).__name__ for f in (self.filters or [])]:
+            # Start parallel grib hack
+            # Make this really specific to GRIBCodec for now - we can make it more general later?
+            # Can we pass parameters for the heuristic thresholds? Use module constants for now.
+            key_count = len(ckeys)
+            if key_count <= PARALLEL_THRESHOLD:
+                logger.info("Chunk count %s <= parallel threshold %s: using serial chunk getitems", key_count, PARALLEL_THRESHOLD)
+                for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
+                    parallel_io_method(self, ckey, chunk_select, out_select, drop_axes, out)
+                return
+
+            logger.info("Chunk count %s > parallel threshold %s: using parallel chunk getitems with batch size %s", key_count, PARALLEL_THRESHOLD, PARALLEL_BATCH_SIZE)
+            # Explicitly use /dev/shm to ensure we are working in memory
+            with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap", dir="/dev/shm") as f:
+                logger.warning("Creating memmap array of shape %s, size %s - this could OOM or exceed the size of /dev/shm", out.shape, out.nbytes)
+                output = np.memmap(f, dtype=out.dtype, shape=out.shape, mode="w+")
+
+                # Just setting mmap_mode to w+ doesn't seem to copy the data back to out...
+                # Hard to choose batch_size without knowing n_jobs; use a module constant here too.
+                Parallel(pre_dispatch="2*n_jobs", batch_size=PARALLEL_BATCH_SIZE)(
+                    delayed(parallel_io_method)(self, ckey, chunk_select, out_select, drop_axes, output)
+                    for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection)
+                )
+
+                out[:] = output[:]
+
+                return
+            # End parallel grib hack
+
         else:
             partial_read_decode = False
             contexts = {}
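
Note on the memmap-plus-joblib pattern above: joblib's default loky backend runs workers in separate processes, so writes a worker makes to an ordinary ndarray would be lost; a numpy.memmap backed by a file in /dev/shm is passed to workers by its backing filename and reopened there, giving genuinely shared memory. A minimal standalone sketch of that pattern outside zarr (the names fill_chunk, N_CHUNKS, and CHUNK_LEN are illustrative, not part of the patch):

    import tempfile

    import numpy as np
    from joblib import Parallel, delayed

    N_CHUNKS = 16  # hypothetical workload: 16 disjoint chunks
    CHUNK_LEN = 4

    def fill_chunk(i, shared):
        # Each worker writes its own disjoint slice, so no locking is needed.
        shared[i * CHUNK_LEN:(i + 1) * CHUNK_LEN] = i

    with tempfile.NamedTemporaryFile(prefix="zarr_memmap", dir="/dev/shm") as f:
        # joblib pickles a memmap by filename, so each worker process reopens
        # the same shared buffer rather than receiving a private copy.
        shared = np.memmap(f, dtype="float64", shape=(N_CHUNKS * CHUNK_LEN,), mode="w+")
        Parallel(n_jobs=4)(delayed(fill_chunk)(i, shared) for i in range(N_CHUNKS))
        result = np.array(shared)  # plain in-memory copy, taken before the tempfile is deleted

As in the patch, the explicit copy at the end matters: the memmap's lifetime is tied to the temporary file, so results must land in an ordinary array (here via np.array, in the patch via out[:] = output[:]) before the with block closes.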
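
The log lines added here come from the module-level logger created with logging.getLogger(__name__), i.e. a logger named "zarr.core". The serial/parallel decision is logged at INFO and per-chunk read failures via logger.exception, so none of it is visible unless the application opts in; one way to do that, as a sketch:

    import logging

    logging.basicConfig(format="%(levelname)s %(name)s: %(message)s", level=logging.INFO)
    # "zarr.core" follows from __name__ in zarr/core.py
    logging.getLogger("zarr.core").setLevel(logging.INFO)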