diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 78f933af7..2fff1980f 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -103,6 +103,45 @@ def register_observation( when: str, **kwargs, ) -> None: + """Add an observation to the context. + + Parameters + ---------- + observation_type + Class type of the observation to register. + name + Name of the metric to observe and result file. + pop_filter + A Pandas query filter string to filter the population down to the + simulants who should be considered for the observation. + when + String name of the phase of a time-step the observation should happen. + Valid values are: `"time_step__prepare"`, `"time_step"`, + `"time_step__cleanup"`, `"collect_metrics"`. + kwargs + Additional keyword arguments to pass to the observation constructor. + + + Returns + ------ + None + + """ + already_used = None + if self.observations: + # NOTE: self.observations is a list where each item is a dictionary + # of the form {event_name: {(pop_filter, stratifications): List[Observation]}}. + # We use a triple-nested for loop to iterate over only the list of Observations + # (i.e. we do not need the event_name, pop_filter, or stratifications). + for observation_details in self.observations.values(): + for observations in observation_details.values(): + for observation in observations: + if observation.name == name: + already_used = observation + if already_used: + raise ValueError( + f"Observation name '{name}' is already used: {str(already_used)}." + ) observation = observation_type(name=name, pop_filter=pop_filter, when=when, **kwargs) self.observations[observation.when][ (observation.pop_filter, observation.stratifications) diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index aafe6b03d..3ceafd244 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -55,7 +55,7 @@ def get_results(self) -> Dict[str, pd.DataFrame]: NOTE: self._results_context.observations is a list where each item is a dictionary of the form {event_name: {(pop_filter, stratifications): List[Observation]}}. - We use a triple-nested for loop to iterative over only the list of Observations + We use a triple-nested for loop to iterate over only the list of Observations (i.e. we do not need the event_name, pop_filter, or stratifications). """ formatted = {} diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 19b8280cf..dbbcae2cd 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -109,7 +109,7 @@ def _aggregate_state_person_time(x: pd.DataFrame) -> float: ], ids=["valid_on_collect_metrics", "valid_on_time_step__prepare"], ) -def test_add_observation(kwargs): +def test_register_observation(kwargs): ctx = ResultsContext() assert len(ctx.observations) == 0 kwargs["results_formatter"] = lambda: None @@ -123,29 +123,28 @@ def test_add_observation(kwargs): assert len(ctx.observations) == 1 -def test_double_add_observation(): - """Tests a double add of the same stratification, this should result in one - additional observation being added to the context.""" +def test_register_observation_duplicate_name_raises(): ctx = ResultsContext() - assert len(ctx.observations) == 0 - kwargs = { - "name": "living_person_time", - "pop_filter": 'alive == "alive" and undead == False', - "when": "collect_metrics", - "results_formatter": lambda: None, - "stratifications": tuple(), - "aggregator_sources": [], - "aggregator": len, - } - ctx.register_observation( - observation_type=AddingObservation, - **kwargs, - ) ctx.register_observation( observation_type=AddingObservation, - **kwargs, + name="some-observation-name", + pop_filter="some-pop-filter", + when="some-when", + results_formatter=lambda df: df, + stratifications=(), + aggregator_sources=[], + aggregator=len, ) - assert len(ctx.observations) == 1 + with pytest.raises( + ValueError, match="Observation name 'some-observation-name' is already used: " + ): + # register a different observation but w/ the same name + ctx.register_observation( + observation_type=ConcatenatingObservation, + name="some-observation-name", + pop_filter="some-other-pop-filter", + when="some-other-when", + ) @pytest.mark.parametrize( diff --git a/tests/helpers.py b/tests/helpers.py index ba7e90ff9..a342e0803 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -51,7 +51,7 @@ def setup(self, builder: Builder) -> None: self.builder_used_for_setup = builder def register_observations(self, builder): - builder.results.register_adding_observation("test", aggregator=self.counter) + builder.results.register_adding_observation(self.name, aggregator=self.counter) def create_lookup_tables(self, builder): return {}