Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel chunk getitems #1

Draft
wants to merge 4 commits into
base: r2.17.0
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
import numpy as np
from numcodecs.compat import ensure_bytes

import tempfile
import logging
logger = logging.getLogger(__name__)

from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
from zarr.attrs import Attributes
from zarr.codecs import AsType, get_codec
Expand Down Expand Up @@ -60,8 +64,30 @@
ensure_ndarray_like,
)


from joblib import Parallel, delayed

__all__ = ["Array"]

# Number of ckeys required to trigger parallelism in _chunk_getitems.
PARALLEL_THRESHOLD = 12

# Number of chunks to batch per joblib worker task in _chunk_getitems.
PARALLEL_BATCH_SIZE = 8

def parallel_io_method(instance, c_key, c_select, out_sel, drop_axes, my_out):
    """Fetch, decode and select one chunk, writing the result into *my_out*.

    Intended to be safe to call from joblib worker tasks: no exception
    propagates out of this function.  A missing chunk key or a failed
    decode both result in the destination region being set to the
    array's fill value.

    Parameters
    ----------
    instance : zarr Array
        Provides ``chunk_store``, ``_decode_chunk`` and ``_fill_value``.
    c_key : str
        Chunk key to read from ``instance.chunk_store``.
    c_select : selection
        Selection applied to the decoded chunk.
    out_sel : selection
        Destination selection within *my_out*.
    drop_axes : tuple of int or falsy
        Axes squeezed out of the selected data before writing.
    my_out : ndarray-like
        Output buffer (possibly a shared memmap) written in place.
    """
    try:
        cdata = instance.chunk_store[c_key]
    except KeyError:
        # A missing key is an expected condition in zarr (chunks that
        # were never written); fill quietly instead of logging a
        # traceback for every absent chunk.
        logger.debug("Chunk %s not found; writing fill value", c_key)
        my_out[out_sel] = instance._fill_value
        return

    try:
        chunk = instance._decode_chunk(cdata)
        tmp = chunk[c_select]
        if drop_axes:
            tmp = np.squeeze(tmp, axis=drop_axes)
        my_out[out_sel] = tmp
    except Exception:
        # The chunk data is present but could not be decoded/selected.
        # Log the traceback (it names the offending key) and fall back
        # to the fill value so one bad chunk does not abort the read.
        logger.exception("Error reading chunk %s", c_key)
        my_out[out_sel] = instance._fill_value


# noinspection PyUnresolvedReferences
class Array:
Expand Down Expand Up @@ -2163,6 +2189,37 @@ def _chunk_getitems(
partial_read_decode = False
values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}


# NOTE(review): this branch only activates when a GRIBCodec filter is
# configured on the array; every other filter stack falls through to the
# serial path in the else branch below.
elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])):
# Start parallel grib hack
# Make this really specific to GRIBCodec for now - we can make this more general later?
# Can we pass parameters for the heuristic behavior thresholds? Use module constants for now
key_count = len(ckeys)
# Small requests: parallel dispatch overhead is not worth it, read
# the chunks serially into `out` and return early.
if key_count <= PARALLEL_THRESHOLD:
logger.info("Chunk Count %s <= Parallel Threshold %s: Using Serial Chunk GetItems", key_count, PARALLEL_THRESHOLD)
for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
parallel_io_method(self, ckey, chunk_select, out_select, drop_axes, out)
return

logger.info("Chunk Count %s greater than Parallel Threshold %s: Using Parallel Chunk GetItems with Parallel Batch Size: %s", key_count, PARALLEL_THRESHOLD, PARALLEL_BATCH_SIZE)
# Explicitly use /dev/shm to ensure we are working in memory
# NOTE(review): hard-codes /dev/shm, so this path is Linux-only —
# presumably acceptable for the target deployment; confirm.
with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap", dir="/dev/shm") as f:
logger.warning("Creating memmap array of shape %s, size %s - this could oom or exceed the size of /dev/shm", out.shape, out.nbytes)
# Shared file-backed buffer so joblib worker processes can all
# write their slices without pickling `out` back and forth.
output = np.memmap(f, dtype=out.dtype, shape=out.shape, mode='w+')

# Just setting mmap_mode to w+ doesn't seem to copy the data back to out...
# Hard to know batch_size without n_jobs. Use a const here too
Parallel(pre_dispatch="2*n_jobs", batch_size=PARALLEL_BATCH_SIZE)(
delayed(parallel_io_method)(self, ckey, chunk_select, out_select, drop_axes, output)
for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection)
)

# Copy the worker-populated memmap back into the caller's buffer
# before the NamedTemporaryFile context deletes the backing file.
out[:] = output[:]

return
# End parallel grib hack

else:
partial_read_decode = False
contexts = {}
Expand Down
Loading