Skip to content

Commit

Permalink
Feature/sbachmei/refactor observation registrations and other minor u…
Browse files Browse the repository at this point in the history
…pdates (#437)
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 6102159 commit a549fee
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 383 deletions.
110 changes: 9 additions & 101 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

from collections import defaultdict
from typing import Callable, Generator, List, Optional, Tuple, Union
from typing import Callable, Generator, List, Optional, Tuple, Type, Union

import pandas as pd
from pandas.core.groupby import DataFrameGroupBy

from vivarium.framework.engine import Builder
from vivarium.framework.results.exceptions import ResultsConfigurationError
from vivarium.framework.results.observation import (
AddingObservation,
ConcatenatingObservation,
StratifiedObservation,
UnstratifiedObservation,
)
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.results.stratification import Stratification


Expand Down Expand Up @@ -100,93 +95,18 @@ def add_stratification(
stratification = Stratification(name, sources, categories, mapper, is_vectorized)
self.stratifications.append(stratification)

def register_stratified_observation(
def register_observation(
self,
observation_type: Type[BaseObservation],
name: str,
pop_filter: str,
when: str,
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
additional_stratifications: List[str],
excluded_stratifications: List[str],
aggregator_sources: Optional[List[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]],
**kwargs,
) -> None:
stratifications = self._get_stratifications(
additional_stratifications, excluded_stratifications
)
observation = StratifiedObservation(
name=name,
pop_filter=pop_filter,
when=when,
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=stratifications,
aggregator_sources=aggregator_sources,
aggregator=aggregator,
)
self.observations[when][(pop_filter, stratifications)].append(observation)

def register_unstratified_observation(
self,
name: str,
pop_filter: str,
when: str,
results_gatherer: Callable[[pd.DataFrame], pd.DataFrame],
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
) -> None:
observation = UnstratifiedObservation(
name=name,
pop_filter=pop_filter,
when=when,
results_gatherer=results_gatherer,
results_updater=results_updater,
results_formatter=results_formatter,
)
self.observations[when][(pop_filter, None)].append(observation)

def register_adding_observation(
self,
name: str,
pop_filter: str,
when: str,
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
additional_stratifications: List[str],
excluded_stratifications: List[str],
aggregator_sources: Optional[List[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]],
) -> None:
stratifications = self._get_stratifications(
additional_stratifications, excluded_stratifications
)
observation = AddingObservation(
name=name,
pop_filter=pop_filter,
when=when,
results_formatter=results_formatter,
stratifications=stratifications,
aggregator_sources=aggregator_sources,
aggregator=aggregator,
)
self.observations[when][(pop_filter, stratifications)].append(observation)

def register_concatenating_observation(
self,
name: str,
pop_filter: str,
when: str,
included_columns: List[str],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
) -> None:
observation = ConcatenatingObservation(
name=name,
pop_filter=pop_filter,
when=when,
included_columns=included_columns,
results_formatter=results_formatter,
)
self.observations[when][(pop_filter, None)].append(observation)
observation = observation_type(name=name, pop_filter=pop_filter, when=when, **kwargs)
self.observations[observation.when][
(observation.pop_filter, observation.stratifications)
].append(observation)

def gather_results(
self, population: pd.DataFrame, event_name: str
Expand Down Expand Up @@ -223,18 +143,6 @@ def gather_results(
aggregates = observation.results_gatherer(pop_groups, stratifications)
yield aggregates, observation.name, observation.results_updater

def _get_stratifications(
self,
additional_stratifications: List[str] = [],
excluded_stratifications: List[str] = [],
) -> Tuple[str, ...]:
stratifications = list(
set(self.default_stratifications) - set(excluded_stratifications)
| set(additional_stratifications)
)
# Makes sure measure identifiers have fields in the same relative order.
return tuple(sorted(stratifications))

@staticmethod
def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame:
return population.query(pop_filter) if pop_filter else population
Expand Down
74 changes: 48 additions & 26 deletions src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,22 @@

import pandas as pd

from vivarium.framework.results.observation import (
AddingObservation,
BaseObservation,
ConcatenatingObservation,
StratifiedObservation,
UnstratifiedObservation,
)

if TYPE_CHECKING:
# cyclic import
from vivarium.framework.results.manager import ResultsManager


def _raise_missing_unstratified_observation_results_gatherer(*args, **kwargs) -> pd.DataFrame:
raise RuntimeError(
"An UnstratifiedObservation has been registered without a `results_gatherer` "
"Callable which is required."
)


def _raise_missing_unstratified_observation_results_updater(*args, **kwargs) -> pd.DataFrame:
raise RuntimeError(
"An UnstratifiedObservation has been registered without a `results_updater` "
"Callable which is required."
)


def _raise_missing_stratified_observation_results_updater(*args, **kwargs) -> pd.DataFrame:
raise RuntimeError(
"A StratifiedObservation has been registered without a `results_updater` "
"Callable which is required."
)
def _required_function_placeholder(*args, **kwargs) -> pd.DataFrame:
"""Placeholder function to indicate that a required function is missing."""
return pd.DataFrame()


class ResultsInterface:
Expand Down Expand Up @@ -170,7 +162,7 @@ def register_stratified_observation(
requires_values: List[str] = [],
results_updater: Callable[
[pd.DataFrame, pd.DataFrame], pd.DataFrame
] = _raise_missing_stratified_observation_results_updater,
] = _required_function_placeholder,
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
Expand Down Expand Up @@ -215,7 +207,10 @@ def register_stratified_observation(
------
None
"""
self._manager.register_stratified_observation(
self._check_for_required_callables(name, {"results_updater": results_updater})
self._manager.register_observation(
observation_type=StratifiedObservation,
is_stratified=True,
name=name,
pop_filter=pop_filter,
when=when,
Expand All @@ -229,6 +224,19 @@ def register_stratified_observation(
aggregator=aggregator,
)

@staticmethod
def _check_for_required_callables(
observation_name: str, required_callables: Dict[str, Callable]
) -> None:
missing = []
for arg_name, callable in required_callables.items():
if callable == _required_function_placeholder:
missing.append(arg_name)
if len(missing) > 0:
raise ValueError(
f"Observation '{observation_name}' is missing required callable(s): {missing}"
)

def register_unstratified_observation(
self,
name: str,
Expand All @@ -238,10 +246,10 @@ def register_unstratified_observation(
requires_values: List[str] = [],
results_gatherer: Callable[
[pd.DataFrame], pd.DataFrame
] = _raise_missing_unstratified_observation_results_gatherer,
] = _required_function_placeholder,
results_updater: Callable[
[pd.DataFrame, pd.DataFrame], pd.DataFrame
] = _raise_missing_unstratified_observation_results_updater,
] = _required_function_placeholder,
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
Expand Down Expand Up @@ -284,7 +292,14 @@ def register_unstratified_observation(
------
None
"""
self._manager.register_unstratified_observation(
required_callables = {
"results_gatherer": results_gatherer,
"results_updater": results_updater,
}
self._check_for_required_callables(name, required_callables)
self._manager.register_observation(
observation_type=UnstratifiedObservation,
is_stratified=False,
name=name,
pop_filter=pop_filter,
when=when,
Expand Down Expand Up @@ -343,7 +358,10 @@ def register_adding_observation(
------
None
"""
self._manager.register_adding_observation(

self._manager.register_observation(
observation_type=AddingObservation,
is_stratified=True,
name=name,
pop_filter=pop_filter,
when=when,
Expand Down Expand Up @@ -390,11 +408,15 @@ def register_concatenating_observation(
------
None
"""
self._manager.register_concatenating_observation(
included_columns = ["event_time"] + requires_columns + requires_values
self._manager.register_observation(
observation_type=ConcatenatingObservation,
is_stratified=False,
name=name,
pop_filter=pop_filter,
when=when,
requires_columns=requires_columns,
requires_values=requires_values,
results_formatter=results_formatter,
included_columns=included_columns,
)
Loading

0 comments on commit a549fee

Please # to comment.