diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 9ddcbef06..df700af85 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -341,9 +341,9 @@ def __init__(self) -> None: self._name: str = "" self._sub_components: List["Component"] = [] self.logger: Optional[Logger] = None - self.get_value_columns: Optional[Callable[[Union[str, pd.DataFrame]], List[str]]] = ( - None - ) + self.get_value_columns: Optional[ + Callable[[Union[str, pd.DataFrame]], List[str]] + ] = None self.configuration: Optional[LayeredConfigTree] = None self.population_view: Optional[PopulationView] = None self.lookup_tables: Dict[str, LookupTable] = {} diff --git a/src/vivarium/examples/disease_model/observer.py b/src/vivarium/examples/disease_model/observer.py index c60f5c54e..fd59b6154 100644 --- a/src/vivarium/examples/disease_model/observer.py +++ b/src/vivarium/examples/disease_model/observer.py @@ -2,11 +2,11 @@ import pandas as pd -from vivarium import Component from vivarium.framework.engine import Builder +from vivarium.framework.results import Observer as Observer_ -class Observer(Component): +class Observer(Observer_): ############## # Properties # ############## @@ -23,6 +23,26 @@ def configuration_defaults(self) -> Dict[str, Any]: def columns_required(self) -> Optional[List[str]]: return ["age", "alive"] + def register_observations(self, builder: Builder) -> None: + builder.results.register_adding_observation( + name="total_population_alive", + requires_columns=["alive"], + pop_filter='alive == "alive"', + ) + builder.results.register_adding_observation( + name="total_population_dead", + requires_columns=["alive"], + pop_filter='alive == "dead"', + ) + builder.results.register_adding_observation( + name="years_of_life_lost", + requires_columns=["age", "alive"], + aggregator=self.calculate_ylls, + ) + + def calculate_ylls(self, df: pd.DataFrame) -> float: + return (self.life_expectancy - df.loc[df["alive"] == "dead", "age"]).sum() + ##################### # Lifecycle methods # ##################### diff --git a/tests/framework/randomness/test_reproducibility.py b/tests/framework/randomness/test_reproducibility.py index f6365e249..8db9c908c 100644 --- a/tests/framework/randomness/test_reproducibility.py +++ b/tests/framework/randomness/test_reproducibility.py @@ -22,9 +22,11 @@ def test_reproducibility(tmp_path, disease_model_spec): check=True, ) - files = [file for file in results_dir.rglob("**/*.hdf")] - assert len(files) == 2 - df1 = pd.read_hdf(files[0]).drop(columns="simulation_run_time") - df2 = pd.read_hdf(files[1]).drop(columns="simulation_run_time") - - assert df1.equals(df2) + files = [file for file in results_dir.rglob("**/*.parquet")] + assert len(files) == 6 + for filename in ["total_population_alive", "total_population_dead", "years_of_life_lost"]: + df_paths = [file for file in files if file.stem == filename] + df1 = pd.read_parquet(df_paths[0]) + df2 = pd.read_parquet(df_paths[1]) + + assert df1.equals(df2)