From 804bf39cddb0c0ab92800a27ebb9e4f7e45fa319 Mon Sep 17 00:00:00 2001 From: Jeffrey Lai Date: Wed, 13 Nov 2024 14:26:42 -0600 Subject: [PATCH] allow make_trajectory_ensemble to accept a subset sequence of BaseDyn objects instead of only strings; useful for custom systems --- dysts/systems.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/dysts/systems.py b/dysts/systems.py index 549c920..c6ef14b 100644 --- a/dysts/systems.py +++ b/dysts/systems.py @@ -5,7 +5,7 @@ from multiprocessing import Pool from os import PathLike from types import ModuleType -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing as npt @@ -100,7 +100,7 @@ def get_system_data( def _compute_trajectory( - equation_name: str, + system: Union[str, BaseDyn], n: int, kwargs: Dict[str, Any], ic_transform: Optional[BaseSampler] = None, @@ -109,19 +109,22 @@ def _compute_trajectory( param_rng: Optional[np.random.Generator] = None, ) -> Array: """A helper function for multiprocessing""" - eq = getattr(dfl, equation_name)() + if isinstance(system, str): + eq = getattr(dfl, system)() + else: + eq = system if param_transform is not None: if param_rng is not None: param_transform.set_rng(param_rng) - eq.transform_params(param_transform) + eq.transform_params(param_transform) # type: ignore # the initial condition transform must come after the parameter transform # because suitable initial conditions may depend on the parameters if ic_transform is not None: if ic_rng is not None: ic_transform.set_rng(ic_rng) - eq.transform_ic(ic_transform) + eq.transform_ic(ic_transform) # type: ignore traj = eq.make_trajectory(n, **kwargs) return traj @@ -133,7 +136,7 @@ def make_trajectory_ensemble( use_multiprocessing: bool = False, ic_transform: Optional[BaseSampler] = None, param_transform: Optional[BaseSampler] = None, - subset: Optional[Sequence[str]] = None, + subset: Optional[Union[Sequence[str], Sequence[BaseDyn]]] = None, ic_rng: Optional[np.random.Generator] = None, param_rng: Optional[np.random.Generator] = None, **kwargs, @@ -149,7 +152,7 @@ def make_trajectory_ensemble( use_multiprocessing (bool): Not yet implemented. ic_transform (callable): function that transforms individual system initial conditions param_transform (callable): function that transforms individual system parameters - subset (list): A list of system names. Defaults to all continuous systems. + subset (list): A list of system names or BaseDyn (e.g. custom dynamical systems). Defaults to all continuous systems. Can also pass in `sys_class` as a kwarg to specify other system classes. kwargs (dict): Integration options passed to each system's make_trajectory() method @@ -172,10 +175,9 @@ def make_trajectory_ensemble( ) else: # stupid lint error fix for subset being possibly None - for equation_name in subset or []: - sol = _compute_trajectory( - equation_name, n, kwargs, ic_transform, param_transform - ) + for system in subset or []: + sol = _compute_trajectory(system, n, kwargs, ic_transform, param_transform) + equation_name = system if isinstance(system, str) else type(system).__name__ all_sols[equation_name] = sol return all_sols @@ -183,7 +185,7 @@ def make_trajectory_ensemble( def _multiprocessed_compute_trajectory( n: int, - subset: Sequence[str], + subset: Union[Sequence[str], Sequence[BaseDyn]], ic_transform: Optional[BaseSampler] = None, param_transform: Optional[BaseSampler] = None, ic_rng: Optional[np.random.Generator] = None, @@ -228,7 +230,8 @@ def _multiprocessed_compute_trajectory( ], ) - return dict(zip(subset, results)) + names = [name if isinstance(name, str) else type(name).__name__ for name in subset] + return dict(zip(names, results)) def compute_trajectory_statistics(