Skip to content

Commit

Permalink
feat: add no_dispatch to dask_method (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Dec 5, 2023
1 parent b28e901 commit daf0dbe
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 11 deletions.
88 changes: 78 additions & 10 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def dask_property(maybe_func: Callable, *, no_dispatch: bool = False) -> _DaskPr
Parameters
----------
maybe_func : Callable, optional
maybe_func : Callable
The property getter function.
no_dispatch : bool
If True, re-use the main getter function as the Dask implementation.
Returns
-------
Callable
The callable dask-awkward aware descriptor factory or the descriptor itself
The dask-awkward aware property descriptor
"""


Expand All @@ -189,7 +189,7 @@ def dask_property(
Returns
-------
Callable
The callable dask-awkward aware descriptor factory or the descriptor itself
The callable dask-awkward aware property descriptor factory
"""
...

Expand Down Expand Up @@ -224,27 +224,95 @@ def dask_property_wrapper(func: Callable) -> _DaskProperty:
return dask_property_wrapper(maybe_func)


def dask_method(func: F) -> F:
class _DaskMethod:
_impl: Callable
_dask_get: Callable | None = None

def __init__(self, impl: Callable):
self._impl = impl

def __get__(
self, instance: T | None, owner: type[T] | None = None
) -> _DaskMethod | Callable:
if instance is None:
return self

return self._impl.__get__(instance, owner)

def dask(self, func: Callable) -> _DaskMethod:
self._dask_get = _make_dask_method(func)
return self


@overload
def dask_method(maybe_func: F, *, no_dispatch: bool = False) -> _DaskMethod:
"""Decorate an instance method to provide a mechanism for overriding the
implementation for dask-awkward arrays via `.dask`.
Parameters
----------
func : Callable
maybe_func : Callable
The method implementation to decorate.
no_dispatch : bool
If True, re-use the main getter function as the Dask implementation
Returns
-------
Callable
The callable dask-awkward aware method.
"""


@overload
def dask_method(
maybe_func: None = None, *, no_dispatch: bool = False
) -> Callable[[F], _DaskMethod]:
"""Decorate an instance method to provide a mechanism for overriding the
implementation for dask-awkward arrays via `.dask`.
Parameters
----------
maybe_func : Callable, optional
The method implementation to decorate.
no_dispatch : bool
If True, re-use the main getter function as the Dask implementation
Returns
-------
Callable
The callable dask-awkward aware method factory.
"""


def dask_method(maybe_func=None, *, no_dispatch=False):
"""Decorate an instance method to provide a mechanism for overriding the
implementation for dask-awkward arrays via `.dask`.
Parameters
----------
maybe_func : Callable, optional
The method implementation to decorate.
no_dispatch : bool
If True, re-use the main getter function as the Dask implementation
Returns
-------
Callable
The callable dask-awkward aware method.
"""

def dask(dask_func_impl: G) -> F:
func._dask_get = _make_dask_method(dask_func_impl) # type: ignore
return func
def dask_method_wrapper(func: F) -> _DaskMethod:
method = _DaskMethod(func)

func.dask = dask # type: ignore
return func
if no_dispatch:
return method.dask(_adapt_naive_dask_get(func))
else:
return method

if maybe_func is None:
return dask_method_wrapper
else:
return dask_method_wrapper(maybe_func)


class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin):
Expand Down
7 changes: 6 additions & 1 deletion tests/test_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def some_method(self):
def some_method_dask(self, array):
return array

@dak.dask_method(no_dispatch=True)
def some_method_both(self):
return "NO DISPATCH!"


@pytest.mark.xfail(
BAD_NP_AK_MIXIN_VERSIONING,
Expand All @@ -69,7 +73,7 @@ def test_distance_behavior(
BAD_NP_AK_MIXIN_VERSIONING,
reason="NumPy 1.25 mixin __slots__ change",
)
def test_property_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None:
def test_property_method_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None:
daa = dak.with_name(daa_p1.points, name="Point", behavior=behaviors)
caa = ak.Array(caa_p1.points, with_name="Point", behavior=behaviors)
assert_eq(daa.x2, caa.x2)
Expand All @@ -87,6 +91,7 @@ def test_property_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None:
== caa.some_property_both
== "this is a dask AND non-dask property"
)
assert daa.some_method_both() == caa.some_method_both() == "NO DISPATCH!"


@pytest.mark.xfail(
Expand Down

0 comments on commit daf0dbe

Please # to comment.