Skip to content

Commit

Permalink
raise when registering an observation name duplicate (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 832687c commit f7f2b89
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
39 changes: 39 additions & 0 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,45 @@ def register_observation(
when: str,
**kwargs,
) -> None:
"""Add an observation to the context.
Parameters
----------
observation_type
Class type of the observation to register.
name
Name of the metric to observe and result file.
pop_filter
A Pandas query filter string to filter the population down to the
simulants who should be considered for the observation.
when
String name of the phase of a time-step the observation should happen.
Valid values are: `"time_step__prepare"`, `"time_step"`,
`"time_step__cleanup"`, `"collect_metrics"`.
kwargs
Additional keyword arguments to pass to the observation constructor.
Returns
------
None
"""
already_used = None
if self.observations:
# NOTE: self.observations is a list where each item is a dictionary
# of the form {event_name: {(pop_filter, stratifications): List[Observation]}}.
# We use a triple-nested for loop to iterate over only the list of Observations
# (i.e. we do not need the event_name, pop_filter, or stratifications).
for observation_details in self.observations.values():
for observations in observation_details.values():
for observation in observations:
if observation.name == name:
already_used = observation
if already_used:
raise ValueError(
f"Observation name '{name}' is already used: {str(already_used)}."
)
observation = observation_type(name=name, pop_filter=pop_filter, when=when, **kwargs)
self.observations[observation.when][
(observation.pop_filter, observation.stratifications)
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_results(self) -> Dict[str, pd.DataFrame]:
NOTE: self._results_context.observations is a list where each item is a dictionary
of the form {event_name: {(pop_filter, stratifications): List[Observation]}}.
We use a triple-nested for loop to iterative over only the list of Observations
We use a triple-nested for loop to iterate over only the list of Observations
(i.e. we do not need the event_name, pop_filter, or stratifications).
"""
formatted = {}
Expand Down
39 changes: 19 additions & 20 deletions tests/framework/results/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _aggregate_state_person_time(x: pd.DataFrame) -> float:
],
ids=["valid_on_collect_metrics", "valid_on_time_step__prepare"],
)
def test_add_observation(kwargs):
def test_register_observation(kwargs):
ctx = ResultsContext()
assert len(ctx.observations) == 0
kwargs["results_formatter"] = lambda: None
Expand All @@ -123,29 +123,28 @@ def test_add_observation(kwargs):
assert len(ctx.observations) == 1


def test_double_add_observation():
"""Tests a double add of the same stratification, this should result in one
additional observation being added to the context."""
def test_register_observation_duplicate_name_raises():
ctx = ResultsContext()
assert len(ctx.observations) == 0
kwargs = {
"name": "living_person_time",
"pop_filter": 'alive == "alive" and undead == False',
"when": "collect_metrics",
"results_formatter": lambda: None,
"stratifications": tuple(),
"aggregator_sources": [],
"aggregator": len,
}
ctx.register_observation(
observation_type=AddingObservation,
**kwargs,
)
ctx.register_observation(
observation_type=AddingObservation,
**kwargs,
name="some-observation-name",
pop_filter="some-pop-filter",
when="some-when",
results_formatter=lambda df: df,
stratifications=(),
aggregator_sources=[],
aggregator=len,
)
assert len(ctx.observations) == 1
with pytest.raises(
ValueError, match="Observation name 'some-observation-name' is already used: "
):
# register a different observation but w/ the same name
ctx.register_observation(
observation_type=ConcatenatingObservation,
name="some-observation-name",
pop_filter="some-other-pop-filter",
when="some-other-when",
)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def setup(self, builder: Builder) -> None:
self.builder_used_for_setup = builder

def register_observations(self, builder):
builder.results.register_adding_observation("test", aggregator=self.counter)
builder.results.register_adding_observation(self.name, aggregator=self.counter)

def create_lookup_tables(self, builder):
return {}
Expand Down

0 comments on commit f7f2b89

Please # to comment.