Skip to content

Commit

Permalink
implement to_observe method to Observations (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 156bbe8 commit 2dc4870
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 50 deletions.
4 changes: 3 additions & 1 deletion src/vivarium/framework/lookup/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame:
)

if self.categorical_parameters:
sub_tables = interpolants.groupby(list(self.categorical_parameters))
sub_tables = interpolants.groupby(
list(self.categorical_parameters), observed=False
)
else:
sub_tables = [(None, interpolants)]
# specify some numeric type for columns, so they won't be objects but
Expand Down
22 changes: 11 additions & 11 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pandas.core.groupby import DataFrameGroupBy

from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.results.exceptions import ResultsConfigurationError
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.results.stratification import Stratification
Expand Down Expand Up @@ -137,9 +138,9 @@ def register_observation(
already_used = None
if self.observations:
# NOTE: self.observations is a list where each item is a dictionary
# of the form {event_name: {(pop_filter, stratifications): List[Observation]}}.
# of the form {lifecycle_phase: {(pop_filter, stratifications): List[Observation]}}.
# We use a triple-nested for loop to iterate over only the list of Observations
# (i.e. we do not need the event_name, pop_filter, or stratifications).
# (i.e. we do not need the lifecycle_phase, pop_filter, or stratifications).
for observation_details in self.observations.values():
for observations in observation_details.values():
for observation in observations:
Expand All @@ -155,7 +156,7 @@ def register_observation(
].append(observation)

def gather_results(
self, population: pd.DataFrame, event_name: str
self, population: pd.DataFrame, lifecycle_phase: str, event: Event
) -> Generator[
Tuple[
Optional[pd.DataFrame],
Expand All @@ -171,7 +172,7 @@ def gather_results(
population = stratification(population)

for (pop_filter, stratifications), observations in self.observations[
event_name
lifecycle_phase
].items():
# Results production can be simplified to
# filter -> groupby -> aggregate in all situations we've seen.
Expand All @@ -180,14 +181,13 @@ def gather_results(
yield None, None, None
else:
if stratifications is None:
for observation in observations:
df = observation.results_gatherer(filtered_pop)
yield df, observation.name, observation.results_updater
pop = filtered_pop
else:
pop_groups = self._get_groups(stratifications, filtered_pop)
for observation in observations:
aggregates = observation.results_gatherer(pop_groups, stratifications)
yield aggregates, observation.name, observation.results_updater
pop = self._get_groups(stratifications, filtered_pop)
for observation in observations:
yield observation.observe(
event, pop, stratifications
), observation.name, observation.results_updater

@staticmethod
def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame:
Expand Down
18 changes: 17 additions & 1 deletion src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import pandas as pd

from vivarium.framework.event import Event
from vivarium.framework.results.observation import (
AddingObservation,
BaseObservation,
ConcatenatingObservation,
StratifiedObservation,
UnstratifiedObservation,
Expand Down Expand Up @@ -170,6 +170,7 @@ def register_stratified_observation(
excluded_stratifications: List[str] = [],
aggregator_sources: Optional[List[str]] = None,
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Provide the results system all the information it needs to perform a
stratified observation.
Expand Down Expand Up @@ -202,6 +203,8 @@ def register_stratified_observation(
A list of population view columns to be used in the aggregator.
aggregator
A function that computes the quantity for the observation.
to_observe
A function that determines whether to perform an observation on this Event.
Returns
------
Expand All @@ -222,6 +225,7 @@ def register_stratified_observation(
excluded_stratifications=excluded_stratifications,
aggregator_sources=aggregator_sources,
aggregator=aggregator,
to_observe=to_observe,
)

@staticmethod
Expand Down Expand Up @@ -253,6 +257,7 @@ def register_unstratified_observation(
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Provide the results system all the information it needs to perform a
stratified observation.
Expand Down Expand Up @@ -287,6 +292,8 @@ def register_unstratified_observation(
A list of population view columns to be used in the aggregator.
aggregator
A function that computes the quantity for the observation.
to_observe
A function that determines whether to perform an observation on this Event.
Returns
------
Expand All @@ -308,6 +315,7 @@ def register_unstratified_observation(
results_updater=results_updater,
results_gatherer=results_gatherer,
results_formatter=results_formatter,
to_observe=to_observe,
)

def register_adding_observation(
Expand All @@ -324,6 +332,7 @@ def register_adding_observation(
excluded_stratifications: List[str] = [],
aggregator_sources: Optional[List[str]] = None,
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Provide the results system all the information it needs to perform the observation.
Expand Down Expand Up @@ -353,6 +362,8 @@ def register_adding_observation(
A list of population view columns to be used in the aggregator.
aggregator
A function that computes the quantity for the observation.
to_observe
A function that determines whether to perform an observation on this Event.
Returns
------
Expand All @@ -372,6 +383,7 @@ def register_adding_observation(
excluded_stratifications=excluded_stratifications,
aggregator_sources=aggregator_sources,
aggregator=aggregator,
to_observe=to_observe,
)

def register_concatenating_observation(
Expand All @@ -384,6 +396,7 @@ def register_concatenating_observation(
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Provide the results system all the information it needs to perform the observation.
Expand All @@ -403,6 +416,8 @@ def register_concatenating_observation(
A list of the value pipelines that are required by either the pop_filter or the aggregator.
results_formatter
A function that formats the observation results.
to_observe
A function that determines whether to perform an observation on this Event.
Returns
------
Expand All @@ -419,4 +434,5 @@ def register_concatenating_observation(
requires_values=requires_values,
results_formatter=results_formatter,
included_columns=included_columns,
to_observe=to_observe,
)
12 changes: 6 additions & 6 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def get_results(self) -> Dict[str, pd.DataFrame]:
"""Return the measure-specific formatted results in a dictionary.
NOTE: self._results_context.observations is a list where each item is a dictionary
of the form {event_name: {(pop_filter, stratifications): List[Observation]}}.
of the form {lifecycle_phase: {(pop_filter, stratifications): List[Observation]}}.
We use a triple-nested for loop to iterate over only the list of Observations
(i.e. we do not need the event_name, pop_filter, or stratifications).
(i.e. we do not need the lifecycle_phase, pop_filter, or stratifications).
"""
formatted = {}
for observation_details in self._results_context.observations.values():
Expand Down Expand Up @@ -91,11 +91,11 @@ def on_post_setup(self, _: Event) -> None:
registered_stratifications = self._results_context.stratifications

used_stratifications = set()
for event_name in self._results_context.observations:
for lifecycle_phase in self._results_context.observations:
for (
_pop_filter,
event_requested_stratification_names,
), observations in self._results_context.observations[event_name].items():
), observations in self._results_context.observations[lifecycle_phase].items():
if event_requested_stratification_names is not None:
used_stratifications |= set(event_requested_stratification_names)
for observation in observations:
Expand Down Expand Up @@ -132,14 +132,14 @@ def on_time_step_cleanup(self, event: Event) -> None:
def on_collect_metrics(self, event: Event) -> None:
self.gather_results("collect_metrics", event)

def gather_results(self, event_name: str, event: Event) -> None:
def gather_results(self, lifecycle_phase: str, event: Event) -> None:
"""Update the existing results with new results. Any columns in the
results group that are not already in the existing results are initialized
with 0.0.
"""
population = self._prepare_population(event)
for results_group, measure, updater in self._results_context.gather_results(
population, event_name
population, lifecycle_phase, event
):
if results_group is not None and measure is not None and updater is not None:
self._raw_results[measure] = updater(
Expand Down
43 changes: 36 additions & 7 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import pandas as pd
from pandas.api.types import CategoricalDtype
from pandas.core.groupby import DataFrameGroupBy

from vivarium.framework.event import Event
from vivarium.framework.results.stratification import Stratification

VALUE_COLUMN = "value"
Expand All @@ -25,6 +26,7 @@ class BaseObservation(ABC):
- `results_gatherer`: method that gathers the new observation results
- `results_updater`: method that updates the results with new observations
- `results_formatter`: method that formats the results
- `to_observe`: method that determines whether to observe an event
"""

name: str
Expand All @@ -35,6 +37,21 @@ class BaseObservation(ABC):
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame]
stratifications: Optional[Tuple[str]]
to_observe: Callable[[Event], bool]

def observe(
self,
event: Event,
df: Union[pd.DataFrame, DataFrameGroupBy],
stratifications: Optional[tuple[str, ...]],
) -> Optional[pd.DataFrame]:
if not self.to_observe(event):
return None
else:
if stratifications is None:
return self.results_gatherer(df)
else:
return self.results_gatherer(df, stratifications)


class UnstratifiedObservation(BaseObservation):
Expand All @@ -46,6 +63,7 @@ class UnstratifiedObservation(BaseObservation):
- `results_gatherer`: method that gathers the new observation results
- `results_updater`: method that updates the results with new observations
- `results_formatter`: method that formats the results
- `to_observe`: method that determines whether to observe an event
"""

def __init__(
Expand All @@ -56,6 +74,7 @@ def __init__(
results_gatherer: Callable[[pd.DataFrame], pd.DataFrame],
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
name=name,
Expand All @@ -66,12 +85,13 @@ def __init__(
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=None,
to_observe=to_observe,
)

@staticmethod
def initialize_results(
requested_stratification_names: set[str],
registered_stratifications: List[Stratification],
registered_stratifications: list[Stratification],
) -> pd.DataFrame:
"""Initialize an empty dataframe."""
return pd.DataFrame()
Expand All @@ -88,6 +108,7 @@ class StratifiedObservation(BaseObservation):
- `stratifications`: a tuple of columns for the observation to stratify by
- `aggregator_sources`: a list of the columns to observe
- `aggregator`: a method that aggregates the `aggregator_sources`
- `to_observe`: method that determines whether to observe an event
"""

def __init__(
Expand All @@ -98,18 +119,20 @@ def __init__(
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
stratifications: Tuple[str, ...],
aggregator_sources: Optional[List[str]],
aggregator_sources: Optional[list[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
name=name,
pop_filter=pop_filter,
when=when,
results_initializer=self.initialize_results,
results_gatherer=self.gather_results,
results_gatherer=self.results_gatherer,
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=stratifications,
to_observe=to_observe,
)
self.aggregator_sources = aggregator_sources
self.aggregator = aggregator
Expand Down Expand Up @@ -150,7 +173,7 @@ def initialize_results(

return df

def gather_results(
def results_gatherer(
self,
pop_groups: DataFrameGroupBy,
stratifications: Tuple[str, ...],
Expand Down Expand Up @@ -203,6 +226,7 @@ class AddingObservation(StratifiedObservation):
- `stratifications`: a tuple of columns for the observation to stratify by
- `aggregator_sources`: a list of the columns to observe
- `aggregator`: a method that aggregates the `aggregator_sources`
- `to_observe`: method that determines whether to observe an event
"""

def __init__(
Expand All @@ -214,6 +238,7 @@ def __init__(
stratifications: Tuple[str, ...],
aggregator_sources: Optional[List[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
name=name,
Expand All @@ -224,6 +249,7 @@ def __init__(
stratifications=stratifications,
aggregator_sources=aggregator_sources,
aggregator=aggregator,
to_observe=to_observe,
)

@staticmethod
Expand Down Expand Up @@ -252,6 +278,7 @@ class ConcatenatingObservation(UnstratifiedObservation):
- `when`: the phase that the observation is registered to
- `included_columns`: the columns to include in the observation
- `results_formatter`: method that formats the results
- `to_observe`: method that determines whether to observe an event
"""

def __init__(
Expand All @@ -261,14 +288,16 @@ def __init__(
when: str,
included_columns: List[str],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
name=name,
pop_filter=pop_filter,
when=when,
results_gatherer=self.gather_results,
results_gatherer=self.results_gatherer,
results_updater=self.concatenate_results,
results_formatter=results_formatter,
to_observe=to_observe,
)
self.included_columns = included_columns

Expand All @@ -280,5 +309,5 @@ def concatenate_results(
return new_observations
return pd.concat([existing_results, new_observations], axis=0).reset_index(drop=True)

def gather_results(self, pop: pd.DataFrame) -> pd.DataFrame:
def results_gatherer(self, pop: pd.DataFrame) -> pd.DataFrame:
return pop[self.included_columns]
Loading

0 comments on commit 2dc4870

Please # to comment.