diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 87800825..0ae0e97c 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -160,7 +160,7 @@ def dask_property(maybe_func: Callable, *, no_dispatch: bool = False) -> _DaskPr Parameters ---------- - maybe_func : Callable, optional + maybe_func : Callable The property getter function. no_dispatch : bool If True, re-use the main getter function as the Dask implementation. @@ -168,7 +168,7 @@ def dask_property(maybe_func: Callable, *, no_dispatch: bool = False) -> _DaskPr Returns ------- Callable - The callable dask-awkward aware descriptor factory or the descriptor itself + The dask-awkward aware property descriptor """ @@ -189,7 +189,7 @@ def dask_property( Returns ------- Callable - The callable dask-awkward aware descriptor factory or the descriptor itself + The callable dask-awkward aware property descriptor factory """ ... @@ -224,13 +224,69 @@ def dask_property_wrapper(func: Callable) -> _DaskProperty: return dask_property_wrapper(maybe_func) -def dask_method(func: F) -> F: +class _DaskMethod: + _impl: Callable + _dask_get: Callable | None = None + + def __init__(self, impl: Callable): + self._impl = impl + + def __get__( + self, instance: T | None, owner: type[T] | None = None + ) -> _DaskMethod | Callable: + if instance is None: + return self + + return self._impl.__get__(instance, owner) + + def dask(self, func: Callable) -> _DaskMethod: + self._dask_get = _make_dask_method(func) + return self + + +@overload +def dask_method(maybe_func: F, *, no_dispatch: bool = False) -> _DaskMethod: """Decorate an instance method to provide a mechanism for overriding the implementation for dask-awkward arrays via `.dask`. Parameters ---------- - func : Callable + maybe_func : Callable + The method implementation to decorate. + + Returns + ------- + Callable + The callable dask-awkward aware method. + """ + + +@overload +def dask_method( + maybe_func: None = None, *, no_dispatch: bool = False +) -> Callable[[F], _DaskMethod]: + """Decorate an instance method to provide a mechanism for overriding the + implementation for dask-awkward arrays via `.dask`. + + Parameters + ---------- + maybe_func : Callable, optional + The method implementation to decorate. + + Returns + ------- + Callable + The callable dask-awkward aware method factory. + """ + + +def dask_method(maybe_func=None, *, no_dispatch=False): + """Decorate an instance method to provide a mechanism for overriding the + implementation for dask-awkward arrays via `.dask`. + + Parameters + ---------- + maybe_func : Callable, optional The method implementation to decorate. Returns @@ -239,12 +295,18 @@ def dask_method(func: F) -> F: The callable dask-awkward aware method. """ - def dask(dask_func_impl: G) -> F: - func._dask_get = _make_dask_method(dask_func_impl) # type: ignore - return func + def dask_method_wrapper(func: F) -> _DaskMethod: + method = _DaskMethod(func) + + if no_dispatch: + return method.dask(_adapt_naive_dask_get(func)) + else: + return method - func.dask = dask # type: ignore - return func + if maybe_func is None: + return dask_method_wrapper + else: + return dask_method_wrapper(maybe_func) class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin): diff --git a/tests/test_behavior.py b/tests/test_behavior.py index 8000a4a6..3a744874 100644 --- a/tests/test_behavior.py +++ b/tests/test_behavior.py @@ -46,6 +46,10 @@ def some_method(self): def some_method_dask(self, array): return array + @dak.dask_method(no_dispatch=True) + def some_method_both(self): + return "NO DISPATCH!" + @pytest.mark.xfail( BAD_NP_AK_MIXIN_VERSIONING, @@ -69,7 +73,7 @@ def test_distance_behavior( BAD_NP_AK_MIXIN_VERSIONING, reason="NumPy 1.25 mixin __slots__ change", ) -def test_property_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None: +def test_property_method_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None: daa = dak.with_name(daa_p1.points, name="Point", behavior=behaviors) caa = ak.Array(caa_p1.points, with_name="Point", behavior=behaviors) assert_eq(daa.x2, caa.x2) @@ -87,6 +91,7 @@ def test_property_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None: == caa.some_property_both == "this is a dask AND non-dask property" ) + assert daa.some_method_both() == caa.some_method_both() == "NO DISPATCH!" @pytest.mark.xfail(