From c274906733a03e52a640ed77cdca4a0d26839ee8 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:56:57 -0600 Subject: [PATCH] Feature/sbachmei/mic 5127 update standard observers (#441) * reset index in default results_formatter * handle some .astype() warnings * more observers configuration_defaults stratification name into getter method --- src/vivarium/framework/results/interface.py | 2 +- src/vivarium/framework/results/observation.py | 8 +++++--- src/vivarium/framework/results/observer.py | 5 ++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 69f8ac842..c69d50e33 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -319,7 +319,7 @@ def register_adding_observation( requires_values: List[str] = [], results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + ] = lambda measure, results: results.reset_index(), additional_stratifications: List[str] = [], excluded_stratifications: List[str] = [], aggregator_sources: Optional[List[str]] = None, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 0ef17de5c..7958ba2c9 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -136,11 +136,13 @@ def initialize_results( df = pd.DataFrame( list(itertools.product(*stratification_values.values())), columns=stratification_names, - ).astype(CategoricalDtype) + ).astype(CategoricalDtype()) else: # We are aggregating the entire population so create a single-row index stratification_names = ["stratification"] - df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype) + df = pd.DataFrame(["all"], columns=stratification_names).astype( + CategoricalDtype() + ) # Initialize a zeros dataframe df[VALUE_COLUMN] = 0.0 @@ -170,7 +172,7 @@ def _aggregate( pop_groups[aggregator_sources].apply(aggregator).fillna(0.0) if aggregator_sources else pop_groups.apply(aggregator) - ) + ).astype(float) return aggregates @staticmethod diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 4efbb4536..417173e0b 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -21,13 +21,16 @@ def __init__(self) -> None: def configuration_defaults(self) -> Dict[str, Any]: return { "stratification": { - self.name.split("_observer")[0]: { + self.get_configuration_name(): { "exclude": [], "include": [], }, }, } + def get_configuration_name(self) -> str: + return self.name.split("_observer")[0] + @abstractmethod def register_observations(self, builder: Builder) -> None: """(Required). Register observations with within each observer."""