diff --git a/src/vivarium/framework/results/__init__.py b/src/vivarium/framework/results/__init__.py index 77c93d311..af4e8a073 100644 --- a/src/vivarium/framework/results/__init__.py +++ b/src/vivarium/framework/results/__init__.py @@ -1,3 +1,4 @@ from vivarium.framework.results.interface import ResultsInterface -from vivarium.framework.results.manager import VALUE_COLUMN, ResultsManager +from vivarium.framework.results.manager import ResultsManager +from vivarium.framework.results.observation import VALUE_COLUMN from vivarium.framework.results.observer import Observer diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 9e8a18eb6..8687da481 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -1,17 +1,13 @@ from __future__ import annotations -import itertools from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import pandas as pd -from pandas.api.types import CategoricalDtype from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext -from vivarium.framework.results.observation import ConcatenatingObservation -from vivarium.framework.results.stratification import Stratification from vivarium.framework.values import Pipeline from vivarium.manager import Manager @@ -19,9 +15,6 @@ from vivarium.framework.engine import Builder -VALUE_COLUMN = "value" - - class SourceType(Enum): COLUMN = 0 VALUE = 1 @@ -89,48 +82,35 @@ def setup(self, builder: "Builder") -> None: def on_post_setup(self, _: Event) -> None: """Initialize results with 0s DataFrame' for each measure and all stratifications""" registered_stratifications = self._results_context.stratifications - registered_stratification_names = set( - stratification.name for stratification in registered_stratifications - ) - - missing_stratifications = {} - unused_stratifications = registered_stratification_names.copy() + used_stratifications = set() for event_name in self._results_context.observations: for ( _pop_filter, - all_requested_stratification_names, + event_requested_stratification_names, ), observations in self._results_context.observations[event_name].items(): + if event_requested_stratification_names is not None: + used_stratifications |= set(event_requested_stratification_names) for observation in observations: measure = observation.name - if all_requested_stratification_names is not None: - df, unused_stratifications = self._initialize_stratified_results( - measure, - all_requested_stratification_names, - registered_stratifications, - registered_stratification_names, - missing_stratifications, - unused_stratifications, - ) - else: - # Initialize a completely empty dataframe - df = pd.DataFrame() - self._raw_results[measure] = df + self._raw_results[measure] = observation.results_initializer( + event_requested_stratification_names, registered_stratifications + ) + registered_stratification_names = set( + stratification.name for stratification in registered_stratifications + ) + unused_stratifications = registered_stratification_names - used_stratifications if unused_stratifications: self.logger.info( "The following stratifications are registered but not used by any " f"observers: \n{sorted(list(unused_stratifications))}" ) + missing_stratifications = used_stratifications - registered_stratification_names if missing_stratifications: - # Sort by observer/measure and then by missing stratifiction - sorted_missing = { - key: sorted(list(missing_stratifications[key])) - for key in sorted(missing_stratifications) - } raise ValueError( "The following observers are requested to be stratified by " - f"stratifications that are not registered: \n{sorted_missing}" + f"stratifications that are not registered: \n{sorted(list(missing_stratifications))}" ) def on_time_step_prepare(self, event: Event) -> None: @@ -330,56 +310,6 @@ def _get_stratifications( # Makes sure measure identifiers have fields in the same relative order. return tuple(sorted(stratifications)) - @staticmethod - def _initialize_stratified_results( - measure: str, - all_requested_stratification_names: List[str], - registered_stratifications: List[Stratification], - registered_stratification_names: Set[str], - missing_stratifications: Dict[str, Set[str]], - unused_stratifications: Set[str], - ) -> Tuple[pd.DataFrame, Set[str]]: - all_requested_stratification_names = set(all_requested_stratification_names) - - # Batch missing stratifications - observer_missing_stratifications = all_requested_stratification_names.difference( - registered_stratification_names - ) - if observer_missing_stratifications: - missing_stratifications[measure] = observer_missing_stratifications - - # Remove stratifications from the running list of unused stratifications - unused_stratifications = unused_stratifications.difference( - all_requested_stratification_names - ) - - # Set up the complete index of all used stratifications - requested_and_registered_stratifications = [ - stratification - for stratification in registered_stratifications - if stratification.name in all_requested_stratification_names - ] - stratification_values = { - stratification.name: stratification.categories - for stratification in requested_and_registered_stratifications - } - if stratification_values: - stratification_names = list(stratification_values.keys()) - df = pd.DataFrame( - list(itertools.product(*stratification_values.values())), - columns=stratification_names, - ).astype(CategoricalDtype) - else: - # We are aggregating the entire population so create a single-row index - stratification_names = ["stratification"] - df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype) - - # Initialize a zeros dataframe - df[VALUE_COLUMN] = 0.0 - df = df.set_index(stratification_names) - - return df, unused_stratifications - def _add_resources(self, target: List[str], target_type: SourceType) -> None: if len(target) == 0: return # do nothing on empty lists diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 2bfe85914..0ef17de5c 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -1,12 +1,18 @@ from __future__ import annotations +import itertools from abc import ABC from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import pandas as pd +from pandas.api.types import CategoricalDtype from pandas.core.groupby import DataFrameGroupBy +from vivarium.framework.results.stratification import Stratification + +VALUE_COLUMN = "value" + @dataclass class BaseObservation(ABC): @@ -15,6 +21,7 @@ class BaseObservation(ABC): - `name`: name of the observation and is the measure it is observing - `pop_filter`: a filter that is applied to the population before the observation is made - `when`: the phase that the observation is registered to + - `results_initializer`: method that initializes the results - `results_gatherer`: method that gathers the new observation results - `results_updater`: method that updates the results with new observations - `results_formatter`: method that formats the results @@ -23,6 +30,7 @@ class BaseObservation(ABC): name: str pop_filter: str when: str + results_initializer: Callable[..., pd.DataFrame] results_gatherer: Callable[..., pd.DataFrame] results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame] @@ -53,12 +61,21 @@ def __init__( name=name, pop_filter=pop_filter, when=when, + results_initializer=self.initialize_results, results_gatherer=results_gatherer, results_updater=results_updater, results_formatter=results_formatter, stratifications=None, ) + @staticmethod + def initialize_results( + requested_stratification_names: set[str], + registered_stratifications: List[Stratification], + ) -> pd.DataFrame: + """Initialize an empty dataframe.""" + return pd.DataFrame() + class StratifiedObservation(BaseObservation): """Container class for managing stratified observations. @@ -88,6 +105,7 @@ def __init__( name=name, pop_filter=pop_filter, when=when, + results_initializer=self.initialize_results, results_gatherer=self.gather_results, results_updater=results_updater, results_formatter=results_formatter, @@ -96,6 +114,40 @@ def __init__( self.aggregator_sources = aggregator_sources self.aggregator = aggregator + @staticmethod + def initialize_results( + requested_stratification_names: set[str], + registered_stratifications: List[Stratification], + ) -> pd.DataFrame: + """Initialize a dataframe of 0s with complete set of stratifications as the index.""" + + # Set up the complete index of all used stratifications + requested_and_registered_stratifications = [ + stratification + for stratification in registered_stratifications + if stratification.name in requested_stratification_names + ] + stratification_values = { + stratification.name: stratification.categories + for stratification in requested_and_registered_stratifications + } + if stratification_values: + stratification_names = list(stratification_values.keys()) + df = pd.DataFrame( + list(itertools.product(*stratification_values.values())), + columns=stratification_names, + ).astype(CategoricalDtype) + else: + # We are aggregating the entire population so create a single-row index + stratification_names = ["stratification"] + df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype) + + # Initialize a zeros dataframe + df[VALUE_COLUMN] = 0.0 + df = df.set_index(stratification_names) + + return df + def gather_results( self, pop_groups: DataFrameGroupBy, diff --git a/tests/framework/results/test_manager.py b/tests/framework/results/test_manager.py index 769b82bef..1e3e0191e 100644 --- a/tests/framework/results/test_manager.py +++ b/tests/framework/results/test_manager.py @@ -414,10 +414,7 @@ def test_observers_with_missing_stratifications_fail(): """ components = [QuidditchWinsObserver(), HousePointsObserver(), Hogwarts()] - expected_missing = { # NOTE: keep in alphabetical order - "house_points": ["power_level_group", "student_house"], - "quidditch_wins": ["familiar"], - } + expected_missing = ["familiar", "power_level_group", "student_house"] expected_log_msg = re.escape( "The following observers are requested to be stratified by stratifications " f"that are not registered: \n{expected_missing}" @@ -432,10 +429,18 @@ def test_unused_stratifications_are_logged(caplog): but never actually used by an Observer The HogwartsResultsStratifier registers "student_house", "familiar", and - "power_level" stratifiers. However, we will only use the HousePointsObserver - component which only requests to be stratified by "student_house" and "power_level" + "power_level_group" stratifiers. However, we will only use the QuidditchWinsObserver + which only uses "familiar" and the MagicalAttributesObserver which only uses + "power_level_group". We would thus expect only "student_house" to be logged + as an unused stratification. + """ - components = [HousePointsObserver(), Hogwarts(), HogwartsResultsStratifier()] + components = [ + Hogwarts(), + HogwartsResultsStratifier(), + QuidditchWinsObserver(), + MagicalAttributesObserver(), + ] InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) log_split = caplog.text.split( @@ -444,7 +449,7 @@ def test_unused_stratifications_are_logged(caplog): # Check that the log message is present and only exists one time assert len(log_split) == 2 # Check that the log message contains the expected Stratifications - assert "['familiar']" in log_split[1] + assert "['student_house']" in log_split[1] def test_stratified_observation_results():