Skip to content

Commit

Permalink
Merge pull request #20 from jbial/ic
Browse files Browse the repository at this point in the history
init cond functionality
  • Loading branch information
williamgilpin authored Aug 12, 2024
2 parents 1f4c374 + b101595 commit 77f278e
Showing 1 changed file with 43 additions and 15 deletions.
58 changes: 43 additions & 15 deletions dysts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
+ numba (optional, for faster integration)
"""


from dataclasses import dataclass, field, asdict
from dataclasses import dataclass, field
import warnings
import json
import collections
Expand All @@ -24,11 +22,12 @@
curr_path = sys.path[0]

import pkg_resources
from typing import Optional, Iterable, Dict, Callable

data_path_continuous = pkg_resources.resource_filename(
DATAPATH_CONTINUOUS = pkg_resources.resource_filename(
"dysts", "data/chaotic_attractors.json"
)
data_path_discrete = pkg_resources.resource_filename("dysts", "data/discrete_maps.json")
DATAPATH_DISCRETE = pkg_resources.resource_filename("dysts", "data/discrete_maps.json")


## Check for optional datasets
Expand All @@ -47,6 +46,7 @@
_has_multiprocessing = False

import numpy as np
import numpy.typing as npt

from .utils import integrate_dyn, standardize_ts
import importlib
Expand Down Expand Up @@ -249,7 +249,7 @@ class DynSys(BaseDyn):
"""

def __init__(self, **kwargs):
self.data_path = data_path_continuous
self.data_path = DATAPATH_CONTINUOUS
super().__init__(**kwargs)
self.dt = self._load_data()["dt"]
self.period = self._load_data()["period"]
Expand Down Expand Up @@ -399,7 +399,7 @@ class DynMap(BaseDyn):
"""

def __init__(self, **kwargs):
self.data_path = data_path_discrete
self.data_path = DATAPATH_DISCRETE
super().__init__(**kwargs)

def rhs(self, X):
Expand Down Expand Up @@ -647,31 +647,32 @@ def get_attractor_list(model_type="continuous"):
attractor_list (list of str): The names of all attractors in database
"""
if model_type == "continuous":
data_path = data_path_continuous
data_path = DATAPATH_CONTINUOUS
else:
data_path = data_path_discrete
data_path = DATAPATH_DISCRETE
with open(data_path, "r") as file:
data = json.load(file)
attractor_list = sorted(list(data.keys()))
return attractor_list

# flows = importlib.import_module("dysts.flows", package=".flows")
import dysts.flows as dfl
def _compute_trajectory(equation_name, n, kwargs):
"""A wrapper function for multiprocessing"""
def _compute_trajectory(equation_name, n, kwargs, init_cond=None):
"""A helper function for multiprocessing"""
eq = getattr(dfl, equation_name)()
if init_cond is not None:
eq.ic = init_cond
traj = eq.make_trajectory(n, **kwargs)
return traj

def make_trajectory_ensemble(n, subset=None, use_multiprocessing=False, random_state=None, use_tqdm=False, **kwargs):
def make_trajectory_ensemble(n, subset=None, use_multiprocessing=False, init_conds={}, use_tqdm=False, **kwargs):
"""
Integrate multiple dynamical systems with identical settings
Args:
n (int): The number of timepoints to integrate
subset (list): A list of system names. Defaults to all systems
use_multiprocessing (bool): Not yet implemented.
random_state (int): The random seed to use for the ensemble
init_cond (dict): Optional user input initial conditions mapping string system name to array
use_tqdm (bool): Whether to use a progress bar
kwargs (dict): Integration options passed to each system's make_trajectory() method
Expand All @@ -682,6 +683,9 @@ def make_trajectory_ensemble(n, subset=None, use_multiprocessing=False, random_s
if not subset:
subset = get_attractor_list()

if len(init_conds) > 0:
assert all(sys in init_conds.keys() for sys in subset), "given initial conditions must at least contain the subset"

if use_tqdm and not use_multiprocessing:
from tqdm import tqdm
subset = tqdm(subset)
Expand All @@ -694,7 +698,7 @@ def make_trajectory_ensemble(n, subset=None, use_multiprocessing=False, random_s
with Pool() as pool:
results = pool.starmap(
_compute_trajectory,
[(equation_name, n, kwargs) for equation_name in subset]
[(equation_name, n, kwargs, init_conds.get(equation_name)) for equation_name in subset]
)
all_sols = dict(zip(subset, results))

Expand All @@ -704,3 +708,27 @@ def make_trajectory_ensemble(n, subset=None, use_multiprocessing=False, random_s
all_sols[equation_name] = sol

return all_sols

def init_cond_sampler(random_seed: Optional[int] = 0, subset: Optional[Iterable] = None) -> Callable:
"""Sample zero mean guassian perturbations for each initial condition in a given system list
Args:
random_seed: for random sampling
subset: A list of system names. Defaults to all systems
Returns:
a function which samples a random perturbation of the init conditions
"""
if not subset:
subset = get_attractor_list()

rng = np.random.default_rng(random_seed)
ic_dict = {sys: np.array(getattr(dfl, sys)().ic) for sys in subset}

def _sampler(scale: Optional[float] = 1e-4) -> Dict[str, npt.NDArray[np.float64]]:
return {
sys: ic + rng.normal(scale=scale*np.linalg.norm(ic), size=ic.shape)
for sys, ic in ic_dict.items()
}

return _sampler

0 comments on commit 77f278e

Please # to comment.