diff --git a/src/dask_awkward/lib/inspect.py b/src/dask_awkward/lib/inspect.py index e5f62602a..860fcaae8 100644 --- a/src/dask_awkward/lib/inspect.py +++ b/src/dask_awkward/lib/inspect.py @@ -4,7 +4,6 @@ import numpy as np from dask.base import unpack_collections -from dask.highlevelgraph import HighLevelGraph from dask_awkward.layers import AwkwardInputLayer @@ -81,8 +80,9 @@ def report_necessary_buffers( name_to_necessary_buffers: dict[str, NecessaryBuffers | None] = {} for obj in collections: - dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask - projection_data = o._prepare_buffer_projection(dsk) + dsk = obj.__dask_graph__() + keys = obj.__dask_keys__() + projection_data = o._prepare_buffer_projection(dsk, keys) # If the projection failed, or there are no input layers if projection_data is None: @@ -178,8 +178,9 @@ def report_necessary_columns( name_to_necessary_columns: dict[str, frozenset | None] = {} for obj in collections: - dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask - projection_data = o._prepare_buffer_projection(dsk) + dsk = obj.__dask_graph__() + keys = obj.__dask_keys__() + projection_data = o._prepare_buffer_projection(dsk, keys) # If the projection failed, or there are no input layers if projection_data is None: diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 2f971873b..37b38d410 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -3,7 +3,7 @@ import copy import logging import warnings -from collections.abc import Hashable, Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast import dask.config @@ -17,6 +17,7 @@ if TYPE_CHECKING: from awkward._nplikes.typetracer import TypeTracerReport + from dask.typing import Key log = logging.getLogger(__name__) @@ -65,7 +66,7 @@ def all_optimizations( def optimize( dsk: HighLevelGraph, - keys: Hashable | list[Hashable] | set[Hashable], + keys: Sequence[Key], **_: Any, ) -> Mapping: """Run optimizations specific to dask-awkward. @@ -77,7 +78,7 @@ def optimize( if dask.config.get("awkward.optimization.enabled"): which = dask.config.get("awkward.optimization.which") if "columns" in which: - dsk = optimize_columns(dsk) + dsk = optimize_columns(dsk, keys) if "layer-chains" in which: dsk = rewrite_layer_chains(dsk, keys) @@ -85,7 +86,7 @@ def optimize( def _prepare_buffer_projection( - dsk: HighLevelGraph, + dsk: HighLevelGraph, keys: Sequence[Key] ) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None: """Pair layer names with lists of necessary columns.""" import awkward as ak @@ -117,18 +118,6 @@ def _prepare_buffer_projection( hlg = HighLevelGraph(projection_layers, dsk.dependencies) - # this loop builds up what are the possible final leaf nodes by - # inspecting the dependents dictionary. If something does not have - # a dependent, it must be the end of a graph. These are the things - # we need to compute for; we only use a single partition (the - # first). for a single collection `.compute()` this list will just - # be length 1; but if we are using `dask.compute` to pass in - # multiple collections to be computed simultaneously, this list - # will increase in length. - leaf_layers_keys = [ - (k, 0) for k, v in dsk.dependents.items() if isinstance(v, set) and len(v) == 0 - ] - # now we try to compute for each possible output layer key (leaf # node on partition 0); this will cause the typetacer reports to # get correct fields/columns touched. If the result is a record or @@ -136,7 +125,7 @@ def _prepare_buffer_projection( try: for layer in hlg.layers.values(): layer.__dict__.pop("_cached_dict", None) - results = get_sync(hlg, leaf_layers_keys) + results = get_sync(hlg, list(keys)) for out in results: if isinstance(out, (ak.Array, ak.Record)): touch_data(out) @@ -163,7 +152,7 @@ def _prepare_buffer_projection( return layer_to_reports, layer_to_projection_state -def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph: +def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph: """Run column projection optimization. This optimization determines which columns from an @@ -192,7 +181,7 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph: New, optimized task graph with column-projected ``AwkwardInputLayer``. """ - projection_data = _prepare_buffer_projection(dsk) + projection_data = _prepare_buffer_projection(dsk, keys) if projection_data is None: return dsk @@ -258,7 +247,7 @@ def _mock_output(layer): return new_layer -def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: +def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph: """Smush chains of blockwise layers into a single layer. The logic here identifies chains by popping layers (in arbitrary @@ -292,7 +281,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: chains = [] deps = copy.copy(dsk.dependencies) - required_layers = {k[0] for k in keys} + required_layers = {k[0] for k in keys if isinstance(k, tuple)} layers = {} # find chains; each chain list is at least two keys long dependents = dsk.dependents