From 2dc4870d264b8a745612d1765a563a13f09a49f7 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:08:36 -0600 Subject: [PATCH] implement to_observe method to Observations (#453) --- .../framework/lookup/interpolation.py | 4 +- src/vivarium/framework/results/context.py | 22 ++-- src/vivarium/framework/results/interface.py | 18 +++- src/vivarium/framework/results/manager.py | 12 +-- src/vivarium/framework/results/observation.py | 43 ++++++-- tests/framework/results/test_context.py | 100 +++++++++++++----- 6 files changed, 149 insertions(+), 50 deletions(-) diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index fe12c621b..02367bb9a 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -109,7 +109,9 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: ) if self.categorical_parameters: - sub_tables = interpolants.groupby(list(self.categorical_parameters)) + sub_tables = interpolants.groupby( + list(self.categorical_parameters), observed=False + ) else: sub_tables = [(None, interpolants)] # specify some numeric type for columns, so they won't be objects but diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 1d3b3977f..5128c25a1 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -7,6 +7,7 @@ from pandas.core.groupby import DataFrameGroupBy from vivarium.framework.engine import Builder +from vivarium.framework.event import Event from vivarium.framework.results.exceptions import ResultsConfigurationError from vivarium.framework.results.observation import BaseObservation from vivarium.framework.results.stratification import Stratification @@ -137,9 +138,9 @@ def register_observation( already_used = None if self.observations: # NOTE: self.observations is a list where each item is a dictionary - # of the form {event_name: {(pop_filter, stratifications): List[Observation]}}. + # of the form {lifecycle_phase: {(pop_filter, stratifications): List[Observation]}}. # We use a triple-nested for loop to iterate over only the list of Observations - # (i.e. we do not need the event_name, pop_filter, or stratifications). + # (i.e. we do not need the lifecycle_phase, pop_filter, or stratifications). for observation_details in self.observations.values(): for observations in observation_details.values(): for observation in observations: @@ -155,7 +156,7 @@ def register_observation( ].append(observation) def gather_results( - self, population: pd.DataFrame, event_name: str + self, population: pd.DataFrame, lifecycle_phase: str, event: Event ) -> Generator[ Tuple[ Optional[pd.DataFrame], @@ -171,7 +172,7 @@ def gather_results( population = stratification(population) for (pop_filter, stratifications), observations in self.observations[ - event_name + lifecycle_phase ].items(): # Results production can be simplified to # filter -> groupby -> aggregate in all situations we've seen. @@ -180,14 +181,13 @@ def gather_results( yield None, None, None else: if stratifications is None: - for observation in observations: - df = observation.results_gatherer(filtered_pop) - yield df, observation.name, observation.results_updater + pop = filtered_pop else: - pop_groups = self._get_groups(stratifications, filtered_pop) - for observation in observations: - aggregates = observation.results_gatherer(pop_groups, stratifications) - yield aggregates, observation.name, observation.results_updater + pop = self._get_groups(stratifications, filtered_pop) + for observation in observations: + yield observation.observe( + event, pop, stratifications + ), observation.name, observation.results_updater @staticmethod def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame: diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index c69d50e33..e251271cf 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -4,9 +4,9 @@ import pandas as pd +from vivarium.framework.event import Event from vivarium.framework.results.observation import ( AddingObservation, - BaseObservation, ConcatenatingObservation, StratifiedObservation, UnstratifiedObservation, @@ -170,6 +170,7 @@ def register_stratified_observation( excluded_stratifications: List[str] = [], aggregator_sources: Optional[List[str]] = None, aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, + to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Provide the results system all the information it needs to perform a stratified observation. @@ -202,6 +203,8 @@ def register_stratified_observation( A list of population view columns to be used in the aggregator. aggregator A function that computes the quantity for the observation. + to_observe + A function that determines whether to perform an observation on this Event. Returns ------ @@ -222,6 +225,7 @@ def register_stratified_observation( excluded_stratifications=excluded_stratifications, aggregator_sources=aggregator_sources, aggregator=aggregator, + to_observe=to_observe, ) @staticmethod @@ -253,6 +257,7 @@ def register_unstratified_observation( results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, + to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Provide the results system all the information it needs to perform a stratified observation. @@ -287,6 +292,8 @@ def register_unstratified_observation( A list of population view columns to be used in the aggregator. aggregator A function that computes the quantity for the observation. + to_observe + A function that determines whether to perform an observation on this Event. Returns ------ @@ -308,6 +315,7 @@ def register_unstratified_observation( results_updater=results_updater, results_gatherer=results_gatherer, results_formatter=results_formatter, + to_observe=to_observe, ) def register_adding_observation( @@ -324,6 +332,7 @@ def register_adding_observation( excluded_stratifications: List[str] = [], aggregator_sources: Optional[List[str]] = None, aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, + to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Provide the results system all the information it needs to perform the observation. @@ -353,6 +362,8 @@ def register_adding_observation( A list of population view columns to be used in the aggregator. aggregator A function that computes the quantity for the observation. + to_observe + A function that determines whether to perform an observation on this Event. Returns ------ @@ -372,6 +383,7 @@ def register_adding_observation( excluded_stratifications=excluded_stratifications, aggregator_sources=aggregator_sources, aggregator=aggregator, + to_observe=to_observe, ) def register_concatenating_observation( @@ -384,6 +396,7 @@ def register_concatenating_observation( results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, + to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Provide the results system all the information it needs to perform the observation. @@ -403,6 +416,8 @@ def register_concatenating_observation( A list of the value pipelines that are required by either the pop_filter or the aggregator. results_formatter A function that formats the observation results. + to_observe + A function that determines whether to perform an observation on this Event. Returns ------ @@ -419,4 +434,5 @@ def register_concatenating_observation( requires_values=requires_values, results_formatter=results_formatter, included_columns=included_columns, + to_observe=to_observe, ) diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 3ceafd244..b07d04106 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -54,9 +54,9 @@ def get_results(self) -> Dict[str, pd.DataFrame]: """Return the measure-specific formatted results in a dictionary. NOTE: self._results_context.observations is a list where each item is a dictionary - of the form {event_name: {(pop_filter, stratifications): List[Observation]}}. + of the form {lifecycle_phase: {(pop_filter, stratifications): List[Observation]}}. We use a triple-nested for loop to iterate over only the list of Observations - (i.e. we do not need the event_name, pop_filter, or stratifications). + (i.e. we do not need the lifecycle_phase, pop_filter, or stratifications). """ formatted = {} for observation_details in self._results_context.observations.values(): @@ -91,11 +91,11 @@ def on_post_setup(self, _: Event) -> None: registered_stratifications = self._results_context.stratifications used_stratifications = set() - for event_name in self._results_context.observations: + for lifecycle_phase in self._results_context.observations: for ( _pop_filter, event_requested_stratification_names, - ), observations in self._results_context.observations[event_name].items(): + ), observations in self._results_context.observations[lifecycle_phase].items(): if event_requested_stratification_names is not None: used_stratifications |= set(event_requested_stratification_names) for observation in observations: @@ -132,14 +132,14 @@ def on_time_step_cleanup(self, event: Event) -> None: def on_collect_metrics(self, event: Event) -> None: self.gather_results("collect_metrics", event) - def gather_results(self, event_name: str, event: Event) -> None: + def gather_results(self, lifecycle_phase: str, event: Event) -> None: """Update the existing results with new results. Any columns in the results group that are not already in the existing results are initialized with 0.0. """ population = self._prepare_population(event) for results_group, measure, updater in self._results_context.gather_results( - population, event_name + population, lifecycle_phase, event ): if results_group is not None and measure is not None and updater is not None: self._raw_results[measure] = updater( diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 7958ba2c9..f43543d5c 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -3,12 +3,13 @@ import itertools from abc import ABC from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import pandas as pd from pandas.api.types import CategoricalDtype from pandas.core.groupby import DataFrameGroupBy +from vivarium.framework.event import Event from vivarium.framework.results.stratification import Stratification VALUE_COLUMN = "value" @@ -25,6 +26,7 @@ class BaseObservation(ABC): - `results_gatherer`: method that gathers the new observation results - `results_updater`: method that updates the results with new observations - `results_formatter`: method that formats the results + - `to_observe`: method that determines whether to observe an event """ name: str @@ -35,6 +37,21 @@ class BaseObservation(ABC): results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame] stratifications: Optional[Tuple[str]] + to_observe: Callable[[Event], bool] + + def observe( + self, + event: Event, + df: Union[pd.DataFrame, DataFrameGroupBy], + stratifications: Optional[tuple[str, ...]], + ) -> Optional[pd.DataFrame]: + if not self.to_observe(event): + return None + else: + if stratifications is None: + return self.results_gatherer(df) + else: + return self.results_gatherer(df, stratifications) class UnstratifiedObservation(BaseObservation): @@ -46,6 +63,7 @@ class UnstratifiedObservation(BaseObservation): - `results_gatherer`: method that gathers the new observation results - `results_updater`: method that updates the results with new observations - `results_formatter`: method that formats the results + - `to_observe`: method that determines whether to observe an event """ def __init__( @@ -56,6 +74,7 @@ def __init__( results_gatherer: Callable[[pd.DataFrame], pd.DataFrame], results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], + to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( name=name, @@ -66,12 +85,13 @@ def __init__( results_updater=results_updater, results_formatter=results_formatter, stratifications=None, + to_observe=to_observe, ) @staticmethod def initialize_results( requested_stratification_names: set[str], - registered_stratifications: List[Stratification], + registered_stratifications: list[Stratification], ) -> pd.DataFrame: """Initialize an empty dataframe.""" return pd.DataFrame() @@ -88,6 +108,7 @@ class StratifiedObservation(BaseObservation): - `stratifications`: a tuple of columns for the observation to stratify by - `aggregator_sources`: a list of the columns to observe - `aggregator`: a method that aggregates the `aggregator_sources` + - `to_observe`: method that determines whether to observe an event """ def __init__( @@ -98,18 +119,20 @@ def __init__( results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], stratifications: Tuple[str, ...], - aggregator_sources: Optional[List[str]], + aggregator_sources: Optional[list[str]], aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], + to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( name=name, pop_filter=pop_filter, when=when, results_initializer=self.initialize_results, - results_gatherer=self.gather_results, + results_gatherer=self.results_gatherer, results_updater=results_updater, results_formatter=results_formatter, stratifications=stratifications, + to_observe=to_observe, ) self.aggregator_sources = aggregator_sources self.aggregator = aggregator @@ -150,7 +173,7 @@ def initialize_results( return df - def gather_results( + def results_gatherer( self, pop_groups: DataFrameGroupBy, stratifications: Tuple[str, ...], @@ -203,6 +226,7 @@ class AddingObservation(StratifiedObservation): - `stratifications`: a tuple of columns for the observation to stratify by - `aggregator_sources`: a list of the columns to observe - `aggregator`: a method that aggregates the `aggregator_sources` + - `to_observe`: method that determines whether to observe an event """ def __init__( @@ -214,6 +238,7 @@ def __init__( stratifications: Tuple[str, ...], aggregator_sources: Optional[List[str]], aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], + to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( name=name, @@ -224,6 +249,7 @@ def __init__( stratifications=stratifications, aggregator_sources=aggregator_sources, aggregator=aggregator, + to_observe=to_observe, ) @staticmethod @@ -252,6 +278,7 @@ class ConcatenatingObservation(UnstratifiedObservation): - `when`: the phase that the observation is registered to - `included_columns`: the columns to include in the observation - `results_formatter`: method that formats the results + - `to_observe`: method that determines whether to observe an event """ def __init__( @@ -261,14 +288,16 @@ def __init__( when: str, included_columns: List[str], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], + to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( name=name, pop_filter=pop_filter, when=when, - results_gatherer=self.gather_results, + results_gatherer=self.results_gatherer, results_updater=self.concatenate_results, results_formatter=results_formatter, + to_observe=to_observe, ) self.included_columns = included_columns @@ -280,5 +309,5 @@ def concatenate_results( return new_observations return pd.concat([existing_results, new_observations], axis=0).reset_index(drop=True) - def gather_results(self, pop: pd.DataFrame) -> pd.DataFrame: + def results_gatherer(self, pop: pd.DataFrame) -> pd.DataFrame: return pop[self.included_columns] diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 376c05f70..d2f47c280 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -17,6 +17,7 @@ sorting_hat_vector, verify_stratification_added, ) +from vivarium.framework.event import Event from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext from vivarium.framework.results.observation import ( @@ -30,6 +31,12 @@ def _aggregate_state_person_time(x: pd.DataFrame) -> float: return len(x) * (28 / 365.25) +@pytest.fixture +def mocked_event(mocker) -> Event: + event: Event = mocker.Mock(spec=Event) + return event + + @pytest.mark.parametrize( "name, sources, categories, mapper, is_vectorized", [ @@ -203,7 +210,7 @@ def test_register_observation_duplicate_name_raises(): ], ) def test_adding_observation_gather_results( - pop_filter, aggregator_sources, aggregator, stratifications + pop_filter, aggregator_sources, aggregator, stratifications, mocked_event ): """Test cases where every stratification is in gather_results. Checks for existence and correctness of results""" @@ -217,7 +224,7 @@ def test_adding_observation_gather_results( population["event_time"] = pd.Timestamp(year=2045, month=1, day=1, hour=12) + timedelta( days=28 ) - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" # Set up stratifications if "house" in stratifications: @@ -231,7 +238,7 @@ def test_adding_observation_gather_results( aggregator_sources=aggregator_sources, aggregator=aggregator, stratifications=tuple(stratifications), - when=event_name, + when=lifecycle_phase, results_formatter=lambda: None, ) @@ -250,7 +257,9 @@ def test_adding_observation_gather_results( ) i = 0 - for result, _measure, _updater in ctx.gather_results(population, event_name): + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): assert all( math.isclose(actual_result, expected_result, rel_tol=0.0001) for actual_result in result.values @@ -259,7 +268,7 @@ def test_adding_observation_gather_results( assert i == 1 -def test_concatenating_observation_gather_results(): +def test_concatenating_observation_gather_results(mocked_event): ctx = ResultsContext() @@ -272,14 +281,14 @@ def test_concatenating_observation_gather_results(): days=28 ) - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" pop_filter = "house=='hufflepuff'" included_cols = ["event_time", "familiar", "house"] ctx.register_observation( observation_type=ConcatenatingObservation, name="foo", pop_filter=pop_filter, - when=event_name, + when=lifecycle_phase, included_columns=included_cols, results_formatter=lambda _, __: pd.DataFrame(), ) @@ -287,7 +296,9 @@ def test_concatenating_observation_gather_results(): filtered_pop = population.query(pop_filter) i = 0 - for result, _measure, _updater in ctx.gather_results(population, event_name): + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): assert result.equals(filtered_pop[included_cols]) i += 1 assert i == 1 @@ -325,7 +336,7 @@ def test_concatenating_observation_gather_results(): ], ) def test_gather_results_partial_stratifications_in_results( - name, pop_filter, aggregator_sources, aggregator, stratifications + name, pop_filter, aggregator_sources, aggregator, stratifications, mocked_event ): """Test cases where not all stratifications are observed for gather_results. This looks for existence of unobserved stratifications and ensures their values are 0""" @@ -343,7 +354,7 @@ def test_gather_results_partial_stratifications_in_results( # Remove an entire category from a stratification population = population[population["familiar"] != "unladen_swallow"].reset_index() - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" # Set up stratifications if "house" in stratifications: @@ -358,17 +369,19 @@ def test_gather_results_partial_stratifications_in_results( aggregator_sources=aggregator_sources, aggregator=aggregator, stratifications=tuple(stratifications), - when=event_name, + when=lifecycle_phase, results_formatter=lambda: None, ) - for results, _measure, _formatter in ctx.gather_results(population, event_name): + for results, _measure, _formatter in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): unladen_results = results.reset_index().query('familiar=="unladen_swallow"') assert len(unladen_results) > 0 assert (unladen_results[VALUE_COLUMN] == 0).all() -def test_gather_results_with_empty_pop_filter(): +def test_gather_results_with_empty_pop_filter(mocked_event): """Test case where pop_filter filters to an empty population. gather_results should return None. """ @@ -377,7 +390,7 @@ def test_gather_results_with_empty_pop_filter(): # Generate population DataFrame population = BASE_POPULATION.copy() - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" ctx.register_observation( observation_type=AddingObservation, name="wizard_count", @@ -385,22 +398,24 @@ def test_gather_results_with_empty_pop_filter(): aggregator_sources=[], aggregator=len, stratifications=tuple(), - when=event_name, + when=lifecycle_phase, results_formatter=lambda: None, ) - for result, _measure, _updater in ctx.gather_results(population, event_name): + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): assert not result -def test_gather_results_with_no_stratifications(): +def test_gather_results_with_no_stratifications(mocked_event): """Test case where we have no stratifications. gather_results should return one value.""" ctx = ResultsContext() # Generate population DataFrame population = BASE_POPULATION.copy() - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" ctx.register_observation( observation_type=AddingObservation, name="wizard_count", @@ -408,7 +423,7 @@ def test_gather_results_with_no_stratifications(): aggregator_sources=None, aggregator=len, stratifications=tuple(), - when=event_name, + when=lifecycle_phase, results_formatter=lambda: None, ) @@ -417,21 +432,23 @@ def test_gather_results_with_no_stratifications(): len( list( result - for result, _measure, _updater in ctx.gather_results(population, event_name) + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ) ) ) == 1 ) -def test_bad_aggregator_stratification(): +def test_bad_aggregator_stratification(mocked_event): """Test if an exception gets raised when a stratification that doesn't exist is attempted to be used, as expected.""" ctx = ResultsContext() # Generate population DataFrame population = BASE_POPULATION.copy() - event_name = "collect_metrics" + lifecycle_phase = "collect_metrics" # Set up stratifications ctx.add_stratification("house", ["house"], CATEGORIES, None, True) @@ -443,12 +460,14 @@ def test_bad_aggregator_stratification(): aggregator_sources=[], aggregator=sum, stratifications=("house", "height"), # `height` is not a stratification - when=event_name, + when=lifecycle_phase, results_formatter=lambda: None, ) with pytest.raises(KeyError, match="height"): - for result, _measure, _updater in ctx.gather_results(population, event_name): + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): print(result) @@ -504,3 +523,36 @@ def test__get_groups(stratifications, values): key, val = item assert key == "all" assert val.equals(BASE_POPULATION.index) + + +def test_to_observe(mocked_event, mocker): + """Test that to_observe can be used to turn off observations""" + ctx = ResultsContext() + + # Generate population DataFrame + population = BASE_POPULATION.copy() + + lifecycle_phase = "collect_metrics" + ctx.register_observation( + observation_type=AddingObservation, + name="wizard_count", + pop_filter="house == 'hufflepuff'", + aggregator_sources=[], + aggregator=len, + stratifications=tuple(), + when=lifecycle_phase, + results_formatter=lambda: None, + ) + + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): + assert not result.empty + + # Extract the observation from the context and patch it to not observe + observation = list(ctx.observations["collect_metrics"].values())[0][0] + mocker.patch.object(observation, "to_observe", return_value=False) + for result, _measure, _updater in ctx.gather_results( + population, lifecycle_phase, mocked_event + ): + assert not result