Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
williamgilpin committed Nov 17, 2024
2 parents 2ace45a + 804bf39 commit b9b6b48
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions dysts/systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from multiprocessing import Pool
from os import PathLike
from types import ModuleType
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -100,7 +100,7 @@ def get_system_data(


def _compute_trajectory(
equation_name: str,
system: Union[str, BaseDyn],
n: int,
kwargs: Dict[str, Any],
ic_transform: Optional[BaseSampler] = None,
Expand All @@ -109,19 +109,22 @@ def _compute_trajectory(
param_rng: Optional[np.random.Generator] = None,
) -> Array:
"""A helper function for multiprocessing"""
eq = getattr(dfl, equation_name)()
if isinstance(system, str):
eq = getattr(dfl, system)()
else:
eq = system

if param_transform is not None:
if param_rng is not None:
param_transform.set_rng(param_rng)
eq.transform_params(param_transform)
eq.transform_params(param_transform) # type: ignore

# the initial condition transform must come after the parameter transform
# because suitable initial conditions may depend on the parameters
if ic_transform is not None:
if ic_rng is not None:
ic_transform.set_rng(ic_rng)
eq.transform_ic(ic_transform)
eq.transform_ic(ic_transform) # type: ignore

traj = eq.make_trajectory(n, **kwargs)
return traj
Expand All @@ -133,7 +136,7 @@ def make_trajectory_ensemble(
use_multiprocessing: bool = False,
ic_transform: Optional[BaseSampler] = None,
param_transform: Optional[BaseSampler] = None,
subset: Optional[Sequence[str]] = None,
subset: Optional[Union[Sequence[str], Sequence[BaseDyn]]] = None,
ic_rng: Optional[np.random.Generator] = None,
param_rng: Optional[np.random.Generator] = None,
**kwargs,
Expand All @@ -149,7 +152,7 @@ def make_trajectory_ensemble(
use_multiprocessing (bool): Not yet implemented.
ic_transform (callable): function that transforms individual system initial conditions
param_transform (callable): function that transforms individual system parameters
subset (list): A list of system names. Defaults to all continuous systems.
subset (list): A list of system names or BaseDyn (e.g. custom dynamical systems). Defaults to all continuous systems.
Can also pass in `sys_class` as a kwarg to specify other system classes.
kwargs (dict): Integration options passed to each system's make_trajectory() method
Expand All @@ -172,18 +175,17 @@ def make_trajectory_ensemble(
)
else:
# stupid lint error fix for subset being possibly None
for equation_name in subset or []:
sol = _compute_trajectory(
equation_name, n, kwargs, ic_transform, param_transform
)
for system in subset or []:
sol = _compute_trajectory(system, n, kwargs, ic_transform, param_transform)
equation_name = system if isinstance(system, str) else type(system).__name__
all_sols[equation_name] = sol

return all_sols


def _multiprocessed_compute_trajectory(
n: int,
subset: Sequence[str],
subset: Union[Sequence[str], Sequence[BaseDyn]],
ic_transform: Optional[BaseSampler] = None,
param_transform: Optional[BaseSampler] = None,
ic_rng: Optional[np.random.Generator] = None,
Expand Down Expand Up @@ -228,7 +230,8 @@ def _multiprocessed_compute_trajectory(
],
)

return dict(zip(subset, results))
names = [name if isinstance(name, str) else type(name).__name__ for name in subset]
return dict(zip(names, results))


def compute_trajectory_statistics(
Expand Down

0 comments on commit b9b6b48

Please # to comment.