Skip to content

Commit

Permalink
remove StratifiedObserver (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent a549fee commit 79dbb46
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/vivarium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
from vivarium.component import Component
from vivarium.framework.artifact import Artifact
from vivarium.framework.configuration import build_model_specification
from vivarium.framework.results.observer import Observer, StratifiedObserver
from vivarium.framework.results.observer import Observer
from vivarium.interface import InteractiveContext
2 changes: 1 addition & 1 deletion src/vivarium/framework/results/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from vivarium.framework.results.interface import ResultsInterface
from vivarium.framework.results.manager import VALUE_COLUMN, ResultsManager
from vivarium.framework.results.observer import Observer, StratifiedObserver
from vivarium.framework.results.observer import Observer
25 changes: 11 additions & 14 deletions src/vivarium/framework/results/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def __init__(self) -> None:
super().__init__()
self.results_dir = None

@property
def configuration_defaults(self) -> Dict[str, Any]:
return {
"stratification": {
self.name.split("_observer")[0]: {
"exclude": [],
"include": [],
},
},
}

@abstractmethod
def register_observations(self, builder: Builder) -> None:
"""(Required). Register observations with within each observer."""
Expand All @@ -34,17 +45,3 @@ def get_formatter_attributes(self, builder: Builder) -> None:
.get("output_data", {})
.get("results_directory", None)
)


# TODO: Move this property into Observer and get rid of StratifiedObserver
class StratifiedObserver(Observer):
@property
def configuration_defaults(self) -> Dict[str, Any]:
return {
"stratification": {
self.name.split("_observer")[0]: {
"exclude": [],
"include": [],
},
},
}
10 changes: 8 additions & 2 deletions tests/framework/components/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,17 @@ def test_component_manager_add_components_duplicated(components):
config = build_simulation_configuration()
cm = ComponentManager()
cm.configuration = config
with pytest.raises(ComponentConfigError, match="duplicate name"):
with pytest.raises(
ComponentConfigError,
match=f"Attempting to add a component with duplicate name: {MockComponentA()}",
):
cm.add_managers(components)

config = build_simulation_configuration()
cm = ComponentManager()
cm.configuration = config
with pytest.raises(ComponentConfigError, match="duplicate name"):
with pytest.raises(
ComponentConfigError,
match=f"Attempting to add a component with duplicate name: {MockComponentA()}",
):
cm.add_components(components)
14 changes: 7 additions & 7 deletions tests/framework/results/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
from vivarium.framework.results import VALUE_COLUMN
from vivarium.framework.results.observer import Observer, StratifiedObserver
from vivarium.framework.results.observer import Observer

NAME = "hogwarts_house"
SOURCES = ["first_name", "last_name"]
Expand Down Expand Up @@ -108,7 +108,7 @@ def on_time_step(self, pop_data: SimulantData) -> None:
self.population_view.update(update)


class HousePointsObserver(StratifiedObserver):
class HousePointsObserver(Observer):
"""Observer that is stratified by multiple columns (the defaults,
'student_house' and 'power_level_group')
"""
Expand All @@ -125,7 +125,7 @@ def register_observations(self, builder: Builder) -> None:
)


class FullyFilteredHousePointsObserver(StratifiedObserver):
class FullyFilteredHousePointsObserver(Observer):
"""Same as `HousePointsObserver but with a filter that leaves no simulants"""

def register_observations(self, builder: Builder) -> None:
Expand All @@ -140,7 +140,7 @@ def register_observations(self, builder: Builder) -> None:
)


class QuidditchWinsObserver(StratifiedObserver):
class QuidditchWinsObserver(Observer):
"""Observer that is stratified by a single column ('familiar')"""

def register_observations(self, builder: Builder) -> None:
Expand All @@ -157,7 +157,7 @@ def register_observations(self, builder: Builder) -> None:
)


class NoStratificationsQuidditchWinsObserver(StratifiedObserver):
class NoStratificationsQuidditchWinsObserver(Observer):
"""Same as above but no stratifications at all"""

def register_observations(self, builder: Builder) -> None:
Expand All @@ -173,7 +173,7 @@ def register_observations(self, builder: Builder) -> None:
)


class MagicalAttributesObserver(StratifiedObserver):
class MagicalAttributesObserver(Observer):
"""Observer whose aggregator returns a pd.Series instead of a float (which in
turn results in a dataframe with multiple columns instead of just one
'value' column)
Expand Down Expand Up @@ -202,7 +202,7 @@ def register_observations(self, builder: Builder) -> None:
)


class CatBombObserver(StratifiedObserver):
class CatBombObserver(Observer):
"""Observer that counts the number of feral cats per house"""

def register_observations(self, builder: Builder) -> None:
Expand Down
23 changes: 3 additions & 20 deletions tests/framework/results/test_observer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import pytest
from layered_config_tree import LayeredConfigTree

from vivarium.framework.results.observer import Observer, StratifiedObserver
from vivarium.framework.results.observer import Observer


class TestObserver(Observer):
def register_observations(self, builder):
pass


class TestDefaultStratifiedObserver(StratifiedObserver):
class TestDefaultObserverStratifications(Observer):
def register_observations(self, builder):
pass


class TestStratifiedObserver(StratifiedObserver):
class TestObserverStratifications(Observer):
def register_observations(self, builder):
pass

Expand Down Expand Up @@ -50,20 +50,3 @@ def test_get_formatter_attributes(is_interactive, results_dir, mocker):
observer.get_formatter_attributes(builder)

assert observer.results_dir == results_dir


@pytest.mark.parametrize(
"observer, name, expected_configuration_defaults",
[
(
TestDefaultStratifiedObserver(),
"test_default_stratified_observer",
{"stratification": {"test_default_stratified": {"exclude": [], "include": []}}},
),
(TestStratifiedObserver(), "test_stratified_observer", {"foo": "bar"}),
],
)
def test_stratified_observer_instantiation(observer, name, expected_configuration_defaults):
obs = observer
assert obs.name == name
assert obs.configuration_defaults == expected_configuration_defaults
9 changes: 6 additions & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

import pandas as pd

from vivarium import Component, Observer, StratifiedObserver
from vivarium import Component, Observer
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.results import VALUE_COLUMN


class MockComponentA(Observer):
@property
def name(self) -> str:
return self._name

@property
def configuration_defaults(self):
return {}

def __init__(self, *args, name="mock_component_a"):
super().__init__()
self._name = name
Expand All @@ -30,7 +33,7 @@ def __eq__(self, other: Any) -> bool:
return type(self) == type(other) and self.name == other.name


class MockComponentB(StratifiedObserver):
class MockComponentB(Observer):
@property
def name(self) -> str:
return self._name
Expand Down

0 comments on commit 79dbb46

Please # to comment.