From 79dbb46b3e33adc1c790bc221b65b7a29f9d0668 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:22:51 -0600 Subject: [PATCH] remove StratifiedObserver (#439) --- src/vivarium/__init__.py | 2 +- src/vivarium/framework/results/__init__.py | 2 +- src/vivarium/framework/results/observer.py | 25 ++++++++++------------ tests/framework/components/test_manager.py | 10 +++++++-- tests/framework/results/helpers.py | 14 ++++++------ tests/framework/results/test_observer.py | 23 +++----------------- tests/helpers.py | 9 +++++--- 7 files changed, 37 insertions(+), 48 deletions(-) diff --git a/src/vivarium/__init__.py b/src/vivarium/__init__.py index 0876ebd85..0cc64d170 100644 --- a/src/vivarium/__init__.py +++ b/src/vivarium/__init__.py @@ -15,5 +15,5 @@ from vivarium.component import Component from vivarium.framework.artifact import Artifact from vivarium.framework.configuration import build_model_specification -from vivarium.framework.results.observer import Observer, StratifiedObserver +from vivarium.framework.results.observer import Observer from vivarium.interface import InteractiveContext diff --git a/src/vivarium/framework/results/__init__.py b/src/vivarium/framework/results/__init__.py index 283c9ab75..77c93d311 100644 --- a/src/vivarium/framework/results/__init__.py +++ b/src/vivarium/framework/results/__init__.py @@ -1,3 +1,3 @@ from vivarium.framework.results.interface import ResultsInterface from vivarium.framework.results.manager import VALUE_COLUMN, ResultsManager -from vivarium.framework.results.observer import Observer, StratifiedObserver +from vivarium.framework.results.observer import Observer diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 3a7a63c82..4efbb4536 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -17,6 +17,17 @@ def __init__(self) -> None: super().__init__() self.results_dir = None + @property + def configuration_defaults(self) -> Dict[str, Any]: + return { + "stratification": { + self.name.split("_observer")[0]: { + "exclude": [], + "include": [], + }, + }, + } + @abstractmethod def register_observations(self, builder: Builder) -> None: """(Required). Register observations with within each observer.""" @@ -34,17 +45,3 @@ def get_formatter_attributes(self, builder: Builder) -> None: .get("output_data", {}) .get("results_directory", None) ) - - -# TODO: Move this property into Observer and get rid of StratifiedObserver -class StratifiedObserver(Observer): - @property - def configuration_defaults(self) -> Dict[str, Any]: - return { - "stratification": { - self.name.split("_observer")[0]: { - "exclude": [], - "include": [], - }, - }, - } diff --git a/tests/framework/components/test_manager.py b/tests/framework/components/test_manager.py index af904e419..81ef509f4 100644 --- a/tests/framework/components/test_manager.py +++ b/tests/framework/components/test_manager.py @@ -286,11 +286,17 @@ def test_component_manager_add_components_duplicated(components): config = build_simulation_configuration() cm = ComponentManager() cm.configuration = config - with pytest.raises(ComponentConfigError, match="duplicate name"): + with pytest.raises( + ComponentConfigError, + match=f"Attempting to add a component with duplicate name: {MockComponentA()}", + ): cm.add_managers(components) config = build_simulation_configuration() cm = ComponentManager() cm.configuration = config - with pytest.raises(ComponentConfigError, match="duplicate name"): + with pytest.raises( + ComponentConfigError, + match=f"Attempting to add a component with duplicate name: {MockComponentA()}", + ): cm.add_components(components) diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index aeb36bcc5..37c5fd295 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -8,7 +8,7 @@ from vivarium.framework.engine import Builder from vivarium.framework.population import SimulantData from vivarium.framework.results import VALUE_COLUMN -from vivarium.framework.results.observer import Observer, StratifiedObserver +from vivarium.framework.results.observer import Observer NAME = "hogwarts_house" SOURCES = ["first_name", "last_name"] @@ -108,7 +108,7 @@ def on_time_step(self, pop_data: SimulantData) -> None: self.population_view.update(update) -class HousePointsObserver(StratifiedObserver): +class HousePointsObserver(Observer): """Observer that is stratified by multiple columns (the defaults, 'student_house' and 'power_level_group') """ @@ -125,7 +125,7 @@ def register_observations(self, builder: Builder) -> None: ) -class FullyFilteredHousePointsObserver(StratifiedObserver): +class FullyFilteredHousePointsObserver(Observer): """Same as `HousePointsObserver but with a filter that leaves no simulants""" def register_observations(self, builder: Builder) -> None: @@ -140,7 +140,7 @@ def register_observations(self, builder: Builder) -> None: ) -class QuidditchWinsObserver(StratifiedObserver): +class QuidditchWinsObserver(Observer): """Observer that is stratified by a single column ('familiar')""" def register_observations(self, builder: Builder) -> None: @@ -157,7 +157,7 @@ def register_observations(self, builder: Builder) -> None: ) -class NoStratificationsQuidditchWinsObserver(StratifiedObserver): +class NoStratificationsQuidditchWinsObserver(Observer): """Same as above but no stratifications at all""" def register_observations(self, builder: Builder) -> None: @@ -173,7 +173,7 @@ def register_observations(self, builder: Builder) -> None: ) -class MagicalAttributesObserver(StratifiedObserver): +class MagicalAttributesObserver(Observer): """Observer whose aggregator returns a pd.Series instead of a float (which in turn results in a dataframe with multiple columns instead of just one 'value' column) @@ -202,7 +202,7 @@ def register_observations(self, builder: Builder) -> None: ) -class CatBombObserver(StratifiedObserver): +class CatBombObserver(Observer): """Observer that counts the number of feral cats per house""" def register_observations(self, builder: Builder) -> None: diff --git a/tests/framework/results/test_observer.py b/tests/framework/results/test_observer.py index 41f190e9d..0ad4d1154 100644 --- a/tests/framework/results/test_observer.py +++ b/tests/framework/results/test_observer.py @@ -1,7 +1,7 @@ import pytest from layered_config_tree import LayeredConfigTree -from vivarium.framework.results.observer import Observer, StratifiedObserver +from vivarium.framework.results.observer import Observer class TestObserver(Observer): @@ -9,12 +9,12 @@ def register_observations(self, builder): pass -class TestDefaultStratifiedObserver(StratifiedObserver): +class TestDefaultObserverStratifications(Observer): def register_observations(self, builder): pass -class TestStratifiedObserver(StratifiedObserver): +class TestObserverStratifications(Observer): def register_observations(self, builder): pass @@ -50,20 +50,3 @@ def test_get_formatter_attributes(is_interactive, results_dir, mocker): observer.get_formatter_attributes(builder) assert observer.results_dir == results_dir - - -@pytest.mark.parametrize( - "observer, name, expected_configuration_defaults", - [ - ( - TestDefaultStratifiedObserver(), - "test_default_stratified_observer", - {"stratification": {"test_default_stratified": {"exclude": [], "include": []}}}, - ), - (TestStratifiedObserver(), "test_stratified_observer", {"foo": "bar"}), - ], -) -def test_stratified_observer_instantiation(observer, name, expected_configuration_defaults): - obs = observer - assert obs.name == name - assert obs.configuration_defaults == expected_configuration_defaults diff --git a/tests/helpers.py b/tests/helpers.py index bf2835c8b..ba7e90ff9 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,11 +2,10 @@ import pandas as pd -from vivarium import Component, Observer, StratifiedObserver +from vivarium import Component, Observer from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import SimulantData -from vivarium.framework.results import VALUE_COLUMN class MockComponentA(Observer): @@ -14,6 +13,10 @@ class MockComponentA(Observer): def name(self) -> str: return self._name + @property + def configuration_defaults(self): + return {} + def __init__(self, *args, name="mock_component_a"): super().__init__() self._name = name @@ -30,7 +33,7 @@ def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.name == other.name -class MockComponentB(StratifiedObserver): +class MockComponentB(Observer): @property def name(self) -> str: return self._name