Skip to content

Commit

Permalink
fix broken repro test (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 79ebbee commit 7e81f4b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/vivarium/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,9 @@ def __init__(self) -> None:
self._name: str = ""
self._sub_components: List["Component"] = []
self.logger: Optional[Logger] = None
self.get_value_columns: Optional[Callable[[Union[str, pd.DataFrame]], List[str]]] = (
None
)
self.get_value_columns: Optional[
Callable[[Union[str, pd.DataFrame]], List[str]]
] = None
self.configuration: Optional[LayeredConfigTree] = None
self.population_view: Optional[PopulationView] = None
self.lookup_tables: Dict[str, LookupTable] = {}
Expand Down
24 changes: 22 additions & 2 deletions src/vivarium/examples/disease_model/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import pandas as pd

from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.results import Observer as Observer_


class Observer(Component):
class Observer(Observer_):
##############
# Properties #
##############
Expand All @@ -23,6 +23,26 @@ def configuration_defaults(self) -> Dict[str, Any]:
def columns_required(self) -> Optional[List[str]]:
return ["age", "alive"]

def register_observations(self, builder: Builder) -> None:
builder.results.register_adding_observation(
name="total_population_alive",
requires_columns=["alive"],
pop_filter='alive == "alive"',
)
builder.results.register_adding_observation(
name="total_population_dead",
requires_columns=["alive"],
pop_filter='alive == "dead"',
)
builder.results.register_adding_observation(
name="years_of_life_lost",
requires_columns=["age", "alive"],
aggregator=self.calculate_ylls,
)

def calculate_ylls(self, df: pd.DataFrame) -> float:
return (self.life_expectancy - df.loc[df["alive"] == "dead", "age"]).sum()

#####################
# Lifecycle methods #
#####################
Expand Down
14 changes: 8 additions & 6 deletions tests/framework/randomness/test_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ def test_reproducibility(tmp_path, disease_model_spec):
check=True,
)

files = [file for file in results_dir.rglob("**/*.hdf")]
assert len(files) == 2
df1 = pd.read_hdf(files[0]).drop(columns="simulation_run_time")
df2 = pd.read_hdf(files[1]).drop(columns="simulation_run_time")

assert df1.equals(df2)
files = [file for file in results_dir.rglob("**/*.parquet")]
assert len(files) == 6
for filename in ["total_population_alive", "total_population_dead", "years_of_life_lost"]:
df_paths = [file for file in files if file.stem == filename]
df1 = pd.read_parquet(df_paths[0])
df2 = pd.read_parquet(df_paths[1])

assert df1.equals(df2)

0 comments on commit 7e81f4b

Please # to comment.