Skip to content

Commit

Permalink
fix bug in event function resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
jbial committed Jan 4, 2025
1 parent f5d679d commit 02cd1fa
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
5 changes: 2 additions & 3 deletions dysts/systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
DynSys,
DynSysDelay,
)
from .utils import num_unspecified_params

Array = npt.NDArray[np.float64]

DEFAULT_RNG = np.random.default_rng()


def get_attractor_list(
sys_class: str = "continuous", exclude: list[str] = []
Expand Down Expand Up @@ -118,7 +117,7 @@ def _resolve_event_signature(
Returns:
The resolved event function (solve_ivp compatible)
"""
if len(inspect.signature(event).parameters) == 1:
if num_unspecified_params(event) == 1:
return event(system) # type: ignore
return event # type: ignore

Expand Down
31 changes: 31 additions & 0 deletions dysts/utils/native_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,37 @@
import os
import threading
import warnings
from inspect import Parameter, signature


def num_unspecified_params(func) -> int:
"""Count required parameters that haven't been specified through partial or defaults.
Args:
func: Function or partial function to inspect
Returns:
Number of parameters that must be specified when calling the function
"""
if hasattr(func, "func"): # Check if partial
partial_func = func.func
partial_args = len(func.args)
partial_keywords = func.keywords or {}
else:
partial_func = func
partial_args = 0
partial_keywords = {}

sig = signature(partial_func)
required_count = 0

for i, (name, param) in enumerate(sig.parameters.items()):
if i < partial_args or name in partial_keywords:
continue
if param.default == Parameter.empty:
required_count += 1

return required_count


def has_module(module_name: str) -> bool:
Expand Down

0 comments on commit 02cd1fa

Please # to comment.