From a549fee6105a52347bf821ad791f848d44f09870 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:08:38 -0600 Subject: [PATCH] Feature/sbachmei/refactor observation registrations and other minor updates (#437) --- src/vivarium/framework/results/context.py | 110 +---------- src/vivarium/framework/results/interface.py | 74 +++++--- src/vivarium/framework/results/manager.py | 128 ++++--------- src/vivarium/framework/results/observation.py | 4 +- src/vivarium/framework/results/observer.py | 1 + tests/framework/results/helpers.py | 36 ++-- tests/framework/results/test_context.py | 171 ++++++------------ tests/framework/results/test_interface.py | 66 +++---- tests/framework/results/test_manager.py | 66 ++++++- 9 files changed, 273 insertions(+), 383 deletions(-) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index c0c3eef67..78f933af7 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -1,19 +1,14 @@ from __future__ import annotations from collections import defaultdict -from typing import Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, Generator, List, Optional, Tuple, Type, Union import pandas as pd from pandas.core.groupby import DataFrameGroupBy from vivarium.framework.engine import Builder from vivarium.framework.results.exceptions import ResultsConfigurationError -from vivarium.framework.results.observation import ( - AddingObservation, - ConcatenatingObservation, - StratifiedObservation, - UnstratifiedObservation, -) +from vivarium.framework.results.observation import BaseObservation from vivarium.framework.results.stratification import Stratification @@ -100,93 +95,18 @@ def add_stratification( stratification = Stratification(name, sources, categories, mapper, is_vectorized) self.stratifications.append(stratification) - def register_stratified_observation( + def register_observation( self, + observation_type: Type[BaseObservation], name: str, pop_filter: str, when: str, - results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - additional_stratifications: List[str], - excluded_stratifications: List[str], - aggregator_sources: Optional[List[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], + **kwargs, ) -> None: - stratifications = self._get_stratifications( - additional_stratifications, excluded_stratifications - ) - observation = StratifiedObservation( - name=name, - pop_filter=pop_filter, - when=when, - results_updater=results_updater, - results_formatter=results_formatter, - stratifications=stratifications, - aggregator_sources=aggregator_sources, - aggregator=aggregator, - ) - self.observations[when][(pop_filter, stratifications)].append(observation) - - def register_unstratified_observation( - self, - name: str, - pop_filter: str, - when: str, - results_gatherer: Callable[[pd.DataFrame], pd.DataFrame], - results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - ) -> None: - observation = UnstratifiedObservation( - name=name, - pop_filter=pop_filter, - when=when, - results_gatherer=results_gatherer, - results_updater=results_updater, - results_formatter=results_formatter, - ) - self.observations[when][(pop_filter, None)].append(observation) - - def register_adding_observation( - self, - name: str, - pop_filter: str, - when: str, - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - additional_stratifications: List[str], - excluded_stratifications: List[str], - aggregator_sources: Optional[List[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], - ) -> None: - stratifications = self._get_stratifications( - additional_stratifications, excluded_stratifications - ) - observation = AddingObservation( - name=name, - pop_filter=pop_filter, - when=when, - results_formatter=results_formatter, - stratifications=stratifications, - aggregator_sources=aggregator_sources, - aggregator=aggregator, - ) - self.observations[when][(pop_filter, stratifications)].append(observation) - - def register_concatenating_observation( - self, - name: str, - pop_filter: str, - when: str, - included_columns: List[str], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - ) -> None: - observation = ConcatenatingObservation( - name=name, - pop_filter=pop_filter, - when=when, - included_columns=included_columns, - results_formatter=results_formatter, - ) - self.observations[when][(pop_filter, None)].append(observation) + observation = observation_type(name=name, pop_filter=pop_filter, when=when, **kwargs) + self.observations[observation.when][ + (observation.pop_filter, observation.stratifications) + ].append(observation) def gather_results( self, population: pd.DataFrame, event_name: str @@ -223,18 +143,6 @@ def gather_results( aggregates = observation.results_gatherer(pop_groups, stratifications) yield aggregates, observation.name, observation.results_updater - def _get_stratifications( - self, - additional_stratifications: List[str] = [], - excluded_stratifications: List[str] = [], - ) -> Tuple[str, ...]: - stratifications = list( - set(self.default_stratifications) - set(excluded_stratifications) - | set(additional_stratifications) - ) - # Makes sure measure identifiers have fields in the same relative order. - return tuple(sorted(stratifications)) - @staticmethod def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame: return population.query(pop_filter) if pop_filter else population diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 1b59c3729..69f8ac842 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -4,30 +4,22 @@ import pandas as pd +from vivarium.framework.results.observation import ( + AddingObservation, + BaseObservation, + ConcatenatingObservation, + StratifiedObservation, + UnstratifiedObservation, +) + if TYPE_CHECKING: # cyclic import from vivarium.framework.results.manager import ResultsManager -def _raise_missing_unstratified_observation_results_gatherer(*args, **kwargs) -> pd.DataFrame: - raise RuntimeError( - "An UnstratifiedObservation has been registered without a `results_gatherer` " - "Callable which is required." - ) - - -def _raise_missing_unstratified_observation_results_updater(*args, **kwargs) -> pd.DataFrame: - raise RuntimeError( - "An UnstratifiedObservation has been registered without a `results_updater` " - "Callable which is required." - ) - - -def _raise_missing_stratified_observation_results_updater(*args, **kwargs) -> pd.DataFrame: - raise RuntimeError( - "A StratifiedObservation has been registered without a `results_updater` " - "Callable which is required." - ) +def _required_function_placeholder(*args, **kwargs) -> pd.DataFrame: + """Placeholder function to indicate that a required function is missing.""" + return pd.DataFrame() class ResultsInterface: @@ -170,7 +162,7 @@ def register_stratified_observation( requires_values: List[str] = [], results_updater: Callable[ [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _raise_missing_stratified_observation_results_updater, + ] = _required_function_placeholder, results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, @@ -215,7 +207,10 @@ def register_stratified_observation( ------ None """ - self._manager.register_stratified_observation( + self._check_for_required_callables(name, {"results_updater": results_updater}) + self._manager.register_observation( + observation_type=StratifiedObservation, + is_stratified=True, name=name, pop_filter=pop_filter, when=when, @@ -229,6 +224,19 @@ def register_stratified_observation( aggregator=aggregator, ) + @staticmethod + def _check_for_required_callables( + observation_name: str, required_callables: Dict[str, Callable] + ) -> None: + missing = [] + for arg_name, callable in required_callables.items(): + if callable == _required_function_placeholder: + missing.append(arg_name) + if len(missing) > 0: + raise ValueError( + f"Observation '{observation_name}' is missing required callable(s): {missing}" + ) + def register_unstratified_observation( self, name: str, @@ -238,10 +246,10 @@ def register_unstratified_observation( requires_values: List[str] = [], results_gatherer: Callable[ [pd.DataFrame], pd.DataFrame - ] = _raise_missing_unstratified_observation_results_gatherer, + ] = _required_function_placeholder, results_updater: Callable[ [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _raise_missing_unstratified_observation_results_updater, + ] = _required_function_placeholder, results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, @@ -284,7 +292,14 @@ def register_unstratified_observation( ------ None """ - self._manager.register_unstratified_observation( + required_callables = { + "results_gatherer": results_gatherer, + "results_updater": results_updater, + } + self._check_for_required_callables(name, required_callables) + self._manager.register_observation( + observation_type=UnstratifiedObservation, + is_stratified=False, name=name, pop_filter=pop_filter, when=when, @@ -343,7 +358,10 @@ def register_adding_observation( ------ None """ - self._manager.register_adding_observation( + + self._manager.register_observation( + observation_type=AddingObservation, + is_stratified=True, name=name, pop_filter=pop_filter, when=when, @@ -390,11 +408,15 @@ def register_concatenating_observation( ------ None """ - self._manager.register_concatenating_observation( + included_columns = ["event_time"] + requires_columns + requires_values + self._manager.register_observation( + observation_type=ConcatenatingObservation, + is_stratified=False, name=name, pop_filter=pop_filter, when=when, requires_columns=requires_columns, requires_values=requires_values, results_formatter=results_formatter, + included_columns=included_columns, ) diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index e6a4718da..9e8a18eb6 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -10,6 +10,7 @@ from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import ConcatenatingObservation from vivarium.framework.results.stratification import Stratification from vivarium.framework.values import Pipeline from vivarium.manager import Manager @@ -269,115 +270,66 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: **target_kwargs, ) - ####################### - # Observation methods # - ####################### - - def register_stratified_observation( + def register_observation( self, + observation_type, + is_stratified: bool, name: str, pop_filter: str, when: str, requires_columns: List[str], requires_values: List[str], - results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - additional_stratifications: List[str], - excluded_stratifications: List[str], - aggregator_sources: Optional[List[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], - ) -> None: + **kwargs, + ): self.logger.debug(f"Registering observation {name}") - self._warn_check_stratifications(additional_stratifications, excluded_stratifications) - self._results_context.register_stratified_observation( - name=name, - pop_filter=pop_filter, - when=when, - results_updater=results_updater, - results_formatter=results_formatter, - additional_stratifications=additional_stratifications, - excluded_stratifications=excluded_stratifications, - aggregator_sources=aggregator_sources, - aggregator=aggregator, - ) - self._add_resources(requires_columns, SourceType.COLUMN) - self._add_resources(requires_values, SourceType.VALUE) - def register_unstratified_observation( - self, - name: str, - pop_filter: str, - when: str, - requires_columns: List[str], - requires_values: List[str], - results_gatherer: Callable[[pd.DataFrame], pd.DataFrame], - results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - ) -> None: - self.logger.debug(f"Registering observation {name}") - self._results_context.register_unstratified_observation( - name=name, - pop_filter=pop_filter, - when=when, - results_gatherer=results_gatherer, - results_updater=results_updater, - results_formatter=results_formatter, - ) - self._add_resources(["event_time"] + requires_columns, SourceType.COLUMN) - self._add_resources(requires_values, SourceType.VALUE) + if is_stratified: + additional_stratifications = kwargs.get("additional_stratifications", []) + excluded_stratifications = kwargs.get("excluded_stratifications", []) + self._warn_check_stratifications( + additional_stratifications, excluded_stratifications + ) + stratifications = self._get_stratifications( + kwargs.get("stratifications", []), + additional_stratifications, + excluded_stratifications, + ) + kwargs["stratifications"] = stratifications + del kwargs["additional_stratifications"] + del kwargs["excluded_stratifications"] - def register_adding_observation( - self, - name: str, - pop_filter: str, - when: str, - requires_columns: List[str], - requires_values: List[str], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - additional_stratifications: List[str], - excluded_stratifications: List[str], - aggregator_sources: Optional[List[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], - ) -> None: - self.logger.debug(f"Registering observation {name}") - self._warn_check_stratifications(additional_stratifications, excluded_stratifications) - self._results_context.register_adding_observation( - name=name, - pop_filter=pop_filter, - when=when, - results_formatter=results_formatter, - additional_stratifications=additional_stratifications, - excluded_stratifications=excluded_stratifications, - aggregator_sources=aggregator_sources, - aggregator=aggregator, - ) self._add_resources(requires_columns, SourceType.COLUMN) self._add_resources(requires_values, SourceType.VALUE) - def register_concatenating_observation( - self, - name: str, - pop_filter: str, - when: str, - requires_columns: List[str], - requires_values: List[str], - results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], - ) -> None: - self.logger.debug(f"Registering observation {name}") - self._results_context.register_concatenating_observation( + self._results_context.register_observation( + observation_type=observation_type, name=name, pop_filter=pop_filter, when=when, - included_columns=["event_time"] + requires_columns + requires_values, - results_formatter=results_formatter, + **kwargs, ) - self._add_resources(["event_time"] + requires_columns, SourceType.COLUMN) - self._add_resources(requires_values, SourceType.VALUE) ################## # Helper methods # ################## + def _get_stratifications( + self, + stratifications: List[str] = [], + additional_stratifications: List[str] = [], + excluded_stratifications: List[str] = [], + ) -> Tuple[str, ...]: + stratifications = list( + set( + self._results_context.default_stratifications + + stratifications + + additional_stratifications + ) + - set(excluded_stratifications) + ) + # Makes sure measure identifiers have fields in the same relative order. + return tuple(sorted(stratifications)) + @staticmethod def _initialize_stratified_results( measure: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index bb7ef661b..2bfe85914 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -26,6 +26,7 @@ class BaseObservation(ABC): results_gatherer: Callable[..., pd.DataFrame] results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame] + stratifications: Optional[Tuple[str]] class UnstratifiedObservation(BaseObservation): @@ -55,6 +56,7 @@ def __init__( results_gatherer=results_gatherer, results_updater=results_updater, results_formatter=results_formatter, + stratifications=None, ) @@ -89,8 +91,8 @@ def __init__( results_gatherer=self.gather_results, results_updater=results_updater, results_formatter=results_formatter, + stratifications=stratifications, ) - self.stratifications = stratifications self.aggregator_sources = aggregator_sources self.aggregator = aggregator diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 6a3aef36c..3a7a63c82 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -36,6 +36,7 @@ def get_formatter_attributes(self, builder: Builder) -> None: ) +# TODO: Move this property into Observer and get rid of StratifiedObserver class StratifiedObserver(Observer): @property def configuration_defaults(self) -> Dict[str, Any]: diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index e32b607ce..aeb36bcc5 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -202,29 +202,37 @@ def register_observations(self, builder: Builder) -> None: ) -class CatLivesObserver(StratifiedObserver): - """Observer that counts the number of cat lives per house""" +class CatBombObserver(StratifiedObserver): + """Observer that counts the number of feral cats per house""" def register_observations(self, builder: Builder) -> None: - builder.results.register_adding_observation( - name="cat_lives", + builder.results.register_stratified_observation( + name="cat_bomb", pop_filter="familiar=='cat' and tracked==True", requires_columns=["familiar"], + results_updater=self.update_cats, excluded_stratifications=["power_level_group"], aggregator_sources=["student_house"], - aggregator=self.count_lives, + aggregator=len, ) - @staticmethod - def count_lives(group): - return len(group) * 9 + def update_cats(self, existing_df, new_df): + no_cats_mask = existing_df["value"] == 0 + updated_df = existing_df + updated_df.loc[no_cats_mask, "value"] = new_df["value"] + updated_df.loc[~no_cats_mask, "value"] *= new_df["value"] + return updated_df class ValedictorianObserver(Observer): """Observer that records the valedictorian at each time step. All students - have the same exam scores and so the valecdictorian is chosen randomly. + have the same exam scores and so the valedictorian is chosen randomly. """ + def __init__(self): + super().__init__() + self.valedictorians = [] + def register_observations(self, builder: Builder) -> None: builder.results.register_unstratified_observation( name="valedictorian", @@ -233,13 +241,13 @@ def register_observations(self, builder: Builder) -> None: results_updater=self.update_valedictorian, ) - @staticmethod - def choose_valedictorian(df): - valedictorian = RNG.choice(df["student_id"]) + def choose_valedictorian(self, df): + eligible_students = df.loc[~df["student_id"].isin(self.valedictorians), "student_id"] + valedictorian = RNG.choice(eligible_students) + self.valedictorians.append(valedictorian) return df[df["student_id"] == valedictorian] - @staticmethod - def update_valedictorian(_existing_df, new_df): + def update_valedictorian(self, _existing_df, new_df): return new_df diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 37772f0ca..19b8280cf 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -18,6 +18,10 @@ ) from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import ( + AddingObservation, + ConcatenatingObservation, +) @pytest.mark.parametrize( @@ -90,121 +94,60 @@ def _aggregate_state_person_time(x: pd.DataFrame) -> float: @pytest.mark.parametrize( - "name, pop_filter, aggregator, additional_stratifications, excluded_stratifications, when", + "kwargs", [ - ( - "living_person_time", - 'alive == "alive" and undead == False', - _aggregate_state_person_time, - [], - [], - "collect_metrics", - ), - ( - "undead_person_time", - "undead == True", - _aggregate_state_person_time, - [], - [], - "time_step__prepare", - ), + { + "name": "living_person_time", + "pop_filter": 'alive == "alive" and undead == False', + "when": "collect_metrics", + }, + { + "name": "undead_person_time", + "pop_filter": "undead == True", + "when": "time_step__prepare", + }, ], ids=["valid_on_collect_metrics", "valid_on_time_step__prepare"], ) -def test_add_observation( - name, pop_filter, aggregator, additional_stratifications, excluded_stratifications, when -): +def test_add_observation(kwargs): ctx = ResultsContext() - ctx._default_stratifications = ["age", "sex"] assert len(ctx.observations) == 0 - ctx.register_adding_observation( - name=name, - pop_filter=pop_filter, - aggregator_sources=[], - aggregator=aggregator, - additional_stratifications=additional_stratifications, - excluded_stratifications=excluded_stratifications, - when=when, - results_formatter=lambda: None, + kwargs["results_formatter"] = lambda: None + kwargs["stratifications"] = tuple() + kwargs["aggregator_sources"] = [] + kwargs["aggregator"] = len + ctx.register_observation( + observation_type=AddingObservation, + **kwargs, ) assert len(ctx.observations) == 1 -@pytest.mark.parametrize( - "name, pop_filter, aggregator, additional_stratifications, excluded_stratifications, when", - [ - ( - "living_person_time", - 'alive == "alive" and undead == False', - _aggregate_state_person_time, - [], - [], - "collect_metrics", - ), - ], - ids=["valid_on_collect_metrics"], -) -def test_double_add_observation( - name, pop_filter, aggregator, additional_stratifications, excluded_stratifications, when -): +def test_double_add_observation(): """Tests a double add of the same stratification, this should result in one additional observation being added to the context.""" ctx = ResultsContext() - ctx._default_stratifications = ["age", "sex"] assert len(ctx.observations) == 0 - ctx.register_adding_observation( - name=name, - pop_filter=pop_filter, - aggregator_sources=[], - aggregator=aggregator, - additional_stratifications=additional_stratifications, - excluded_stratifications=excluded_stratifications, - when=when, - results_formatter=lambda: None, + kwargs = { + "name": "living_person_time", + "pop_filter": 'alive == "alive" and undead == False', + "when": "collect_metrics", + "results_formatter": lambda: None, + "stratifications": tuple(), + "aggregator_sources": [], + "aggregator": len, + } + ctx.register_observation( + observation_type=AddingObservation, + **kwargs, ) - ctx.register_adding_observation( - name=name, - pop_filter=pop_filter, - aggregator_sources=[], - aggregator=aggregator, - additional_stratifications=additional_stratifications, - excluded_stratifications=excluded_stratifications, - when=when, - results_formatter=lambda: None, + ctx.register_observation( + observation_type=AddingObservation, + **kwargs, ) assert len(ctx.observations) == 1 -@pytest.mark.parametrize( - "default_stratifications, additional_stratifications, excluded_stratifications, expected_stratifications", - [ - ([], [], [], ()), - (["age", "sex"], ["handedness"], ["age"], ("sex", "handedness")), - (["age", "sex"], [], ["age", "sex"], ()), - (["age"], [], ["bogus_exclude_column"], ("age",)), - ], - ids=[ - "empty_add_empty_exclude", - "one_add_one_exclude", - "all_defaults_excluded", - "bogus_exclude", - ], -) -def test__get_stratifications( - default_stratifications, - additional_stratifications, - excluded_stratifications, - expected_stratifications, -): - ctx = ResultsContext() - # default_stratifications would normally be set via ResultsInterface.set_default_stratifications() - ctx.default_stratifications = default_stratifications - stratifications = ctx._get_stratifications( - additional_stratifications, excluded_stratifications - ) - assert sorted(stratifications) == sorted(expected_stratifications) - - @pytest.mark.parametrize( "pop_filter, aggregator_sources, aggregator, stratifications", [ @@ -261,13 +204,13 @@ def test_adding_observation_gather_results( ctx.add_stratification("house", ["house"], CATEGORIES, None, True) if "familiar" in stratifications: ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) - ctx.register_adding_observation( + ctx.register_observation( + observation_type=AddingObservation, name="foo", pop_filter=pop_filter, aggregator_sources=aggregator_sources, aggregator=aggregator, - additional_stratifications=stratifications, - excluded_stratifications=[], + stratifications=tuple(stratifications), when=event_name, results_formatter=lambda: None, ) @@ -312,7 +255,8 @@ def test_concatenating_observation_gather_results(): event_name = "collect_metrics" pop_filter = "house=='hufflepuff'" included_cols = ["event_time", "familiar", "house"] - ctx.register_concatenating_observation( + ctx.register_observation( + observation_type=ConcatenatingObservation, name="foo", pop_filter=pop_filter, when=event_name, @@ -386,19 +330,20 @@ def test_gather_results_partial_stratifications_in_results( ctx.add_stratification("house", ["house"], CATEGORIES, None, True) if "familiar" in stratifications: ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) - ctx.register_adding_observation( + + ctx.register_observation( + observation_type=AddingObservation, name=name, pop_filter=pop_filter, aggregator_sources=aggregator_sources, aggregator=aggregator, - additional_stratifications=stratifications, - excluded_stratifications=[], + stratifications=tuple(stratifications), when=event_name, results_formatter=lambda: None, ) for results, _measure, _formatter in ctx.gather_results(population, event_name): - unladen_results = results.loc["unladen_swallow"] + unladen_results = results.reset_index().query('familiar=="unladen_swallow"') assert len(unladen_results) > 0 assert (unladen_results[VALUE_COLUMN] == 0).all() @@ -413,13 +358,13 @@ def test_gather_results_with_empty_pop_filter(): population = BASE_POPULATION.copy() event_name = "collect_metrics" - ctx.register_adding_observation( + ctx.register_observation( + observation_type=AddingObservation, name="wizard_count", pop_filter="house == 'durmstrang'", aggregator_sources=[], aggregator=len, - additional_stratifications=[], - excluded_stratifications=[], + stratifications=tuple(), when=event_name, results_formatter=lambda: None, ) @@ -436,13 +381,13 @@ def test_gather_results_with_no_stratifications(): population = BASE_POPULATION.copy() event_name = "collect_metrics" - ctx.register_adding_observation( + ctx.register_observation( + observation_type=AddingObservation, name="wizard_count", pop_filter="", aggregator_sources=None, aggregator=len, - additional_stratifications=[], - excluded_stratifications=[], + stratifications=tuple(), when=event_name, results_formatter=lambda: None, ) @@ -471,13 +416,13 @@ def test_bad_aggregator_stratification(): # Set up stratifications ctx.add_stratification("house", ["house"], CATEGORIES, None, True) ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) - ctx.register_adding_observation( + ctx.register_observation( + observation_type=AddingObservation, name="this_shouldnt_work", pop_filter="", aggregator_sources=[], aggregator=sum, - additional_stratifications=["house", "height"], # `height` is not a stratification - excluded_stratifications=[], + stratifications=("house", "height"), # `height` is not a stratification when=event_name, results_formatter=lambda: None, ) diff --git a/tests/framework/results/test_interface.py b/tests/framework/results/test_interface.py index 12ba5efcb..5876c2de4 100644 --- a/tests/framework/results/test_interface.py +++ b/tests/framework/results/test_interface.py @@ -1,3 +1,4 @@ +import re from datetime import timedelta from types import MethodType @@ -14,6 +15,29 @@ def _silly_aggregator(_: pd.DataFrame) -> float: return 1.0 +@pytest.mark.parametrize( + ("obs_type", "missing_args"), + [ + ("StratifiedObservation", ["results_updater"]), + ("UnstratifiedObservation", ["results_gatherer", "results_updater"]), + ], +) +def test_register_observation_raises(obs_type, missing_args, mocker): + builder = mocker.Mock() + builder.configuration.stratification.default = [] + mgr = ResultsManager() + mgr.setup(builder) + interface = ResultsInterface(mgr) + match = re.escape( + f"Observation 'some-name' is missing required callable(s): {missing_args}", + ) + with pytest.raises(ValueError, match=match): + if obs_type == "StratifiedObservation": + interface.register_stratified_observation(name="some-name") + if obs_type == "UnstratifiedObservation": + interface.register_unstratified_observation(name="some-name") + + def test_register_stratified_observation(mocker): mgr = ResultsManager() interface = ResultsInterface(mgr) @@ -54,28 +78,6 @@ def test_register_stratified_observation(mocker): assert obs.aggregator_sources is None -def test_register_stratified_observation_raises(mocker): - builder = mocker.Mock() - builder.configuration.stratification.default = [] - mgr = ResultsManager() - mgr.setup(builder) - interface = ResultsInterface(mgr) - with pytest.raises( - RuntimeError, - match=( - "A StratifiedObservation has been registered without a `results_updater` " - "Callable which is required." - ), - ): - interface.register_stratified_observation(name="some-name") - observations = interface._manager._results_context.observations - ((_filter, _stratifications), observation) = list( - observations["collect_metrics"].items() - )[0] - obs = observation[0] - obs.results_updater() - - def test_register_unstratified_observation(mocker): mgr = ResultsManager() interface = ResultsInterface(mgr) @@ -92,7 +94,7 @@ def test_register_unstratified_observation(mocker): requires_columns=["some-column", "some-other-column"], requires_values=["some-value", "some-other-value"], results_gatherer=lambda _: pd.DataFrame(), - results_formatter=lambda _, __: pd.DataFrame(), + results_updater=lambda _, __: pd.DataFrame(), ) observations = interface._manager._results_context.observations assert len(observations) == 1 @@ -233,15 +235,15 @@ def test_unhashable_pipeline(mocker): assert len(interface._manager._results_context.observations) == 0 with pytest.raises(TypeError, match="unhashable"): interface.register_adding_observation( - "living_person_time", - 'alive == "alive" and undead == False', - [], - _silly_aggregator, - [], - [["bad", "unhashable", "thing"]], # unhashable first element - [], - [], - "collect_metrics", + name="living_person_time", + pop_filter='alive == "alive" and undead == False', + when="collect_metrics", + requires_columns=[], + requires_values=[["bad", "unhashable", "thing"]], # unhashable first element + additional_stratifications=[], + excluded_stratifications=[], + aggregator_sources=[], + aggregator=_silly_aggregator, ) diff --git a/tests/framework/results/test_manager.py b/tests/framework/results/test_manager.py index 1672ec7ba..769b82bef 100644 --- a/tests/framework/results/test_manager.py +++ b/tests/framework/results/test_manager.py @@ -19,7 +19,7 @@ POWER_LEVEL_GROUP_LABELS, SOURCES, STUDENT_HOUSES, - CatLivesObserver, + CatBombObserver, ExamScoreObserver, FullyFilteredHousePointsObserver, Hogwarts, @@ -35,9 +35,54 @@ verify_stratification_added, ) from vivarium.framework.results import VALUE_COLUMN +from vivarium.framework.results.context import ResultsContext from vivarium.framework.results.manager import ResultsManager +from vivarium.framework.results.observation import AddingObservation from vivarium.interface.interactive import InteractiveContext + +@pytest.mark.parametrize( + "stratifications, default_stratifications, additional_stratifications, excluded_stratifications, expected_stratifications", + [ + ([], [], [], [], ()), + ( + [], + ["age", "sex"], + ["handedness"], + ["age"], + ("sex", "handedness"), + ), + ([], ["age", "sex"], [], ["age", "sex"], ()), + ([], ["age"], [], ["bogus_exclude_column"], ("age",)), + (["custom"], ["age", "sex"], [], [], ("custom", "age", "sex")), + ], + ids=[ + "empty_add_empty_exclude", + "one_add_one_exclude", + "all_defaults_excluded", + "bogus_exclude", + "custom_stratification", + ], +) +def test__get_stratifications( + stratifications, + default_stratifications, + additional_stratifications, + excluded_stratifications, + expected_stratifications, + mocker, +): + ctx = ResultsContext() + ctx.default_stratifications = default_stratifications + mgr = ResultsManager() + mocker.patch.object(mgr, "_results_context", ctx) + # default_stratifications would normally be set via ResultsInterface.set_default_stratifications() + stratifications = mgr._get_stratifications( + stratifications, additional_stratifications, excluded_stratifications + ) + assert sorted(stratifications) == sorted(expected_stratifications) + + ####################################### # Tests for `register_stratification` # ####################################### @@ -259,7 +304,9 @@ def test_add_observation_nop_stratifications( mgr.logger = logger mgr._results_context.default_stratifications = default - mgr.register_adding_observation( + mgr.register_observation( + observation_type=AddingObservation, + is_stratified=True, name="name", pop_filter='alive == "alive"', aggregator_sources=[], @@ -403,18 +450,21 @@ def test_unused_stratifications_are_logged(caplog): def test_stratified_observation_results(): components = [ Hogwarts(), - CatLivesObserver(), + CatBombObserver(), HogwartsResultsStratifier(), ] sim = InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) - assert (sim.get_results()["cat_lives"]["value"] == 0.0).all() + assert (sim.get_results()["cat_bomb"]["value"] == 0.0).all() sim.step() num_familiars = sim.get_population().groupby(["familiar", "student_house"]).apply(len) - expected = num_familiars.loc["cat"] * 9.0 + expected = num_familiars.loc["cat"] ** 1.0 expected.name = "value" - assert expected.sort_values().equals( - sim.get_results()["cat_lives"]["value"].sort_values() - ) + assert expected.sort_values().equals(sim.get_results()["cat_bomb"]["value"].sort_values()) + sim.step() + num_familiars = sim.get_population().groupby(["familiar", "student_house"]).apply(len) + expected = num_familiars.loc["cat"] ** 2.0 + expected.name = "value" + assert expected.sort_values().equals(sim.get_results()["cat_bomb"]["value"].sort_values()) def test_unstratified_observation_results():