From 02cd1fa424246f195856906113d3ae00d364833e Mon Sep 17 00:00:00 2001 From: Jeffrey Lai Date: Sat, 4 Jan 2025 01:13:30 -0800 Subject: [PATCH] fix bug in event function resolver --- dysts/systems.py | 5 ++--- dysts/utils/native_utils.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/dysts/systems.py b/dysts/systems.py index 2153ce0..20de82c 100644 --- a/dysts/systems.py +++ b/dysts/systems.py @@ -21,11 +21,10 @@ DynSys, DynSysDelay, ) +from .utils import num_unspecified_params Array = npt.NDArray[np.float64] -DEFAULT_RNG = np.random.default_rng() - def get_attractor_list( sys_class: str = "continuous", exclude: list[str] = [] @@ -118,7 +117,7 @@ def _resolve_event_signature( Returns: The resolved event function (solve_ivp compatible) """ - if len(inspect.signature(event).parameters) == 1: + if num_unspecified_params(event) == 1: return event(system) # type: ignore return event # type: ignore diff --git a/dysts/utils/native_utils.py b/dysts/utils/native_utils.py index 017f1a2..e65ab53 100644 --- a/dysts/utils/native_utils.py +++ b/dysts/utils/native_utils.py @@ -6,6 +6,37 @@ import os import threading import warnings +from inspect import Parameter, signature + + +def num_unspecified_params(func) -> int: + """Count required parameters that haven't been specified through partial or defaults. + + Args: + func: Function or partial function to inspect + + Returns: + Number of parameters that must be specified when calling the function + """ + if hasattr(func, "func"): # Check if partial + partial_func = func.func + partial_args = len(func.args) + partial_keywords = func.keywords or {} + else: + partial_func = func + partial_args = 0 + partial_keywords = {} + + sig = signature(partial_func) + required_count = 0 + + for i, (name, param) in enumerate(sig.parameters.items()): + if i < partial_args or name in partial_keywords: + continue + if param.default == Parameter.empty: + required_count += 1 + + return required_count def has_module(module_name: str) -> bool: