Skip to content

Commit

Permalink
move results initialization to Observation attributes (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent 79dbb46 commit 98dbcd4
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 93 deletions.
3 changes: 2 additions & 1 deletion src/vivarium/framework/results/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from vivarium.framework.results.interface import ResultsInterface
from vivarium.framework.results.manager import VALUE_COLUMN, ResultsManager
from vivarium.framework.results.manager import ResultsManager
from vivarium.framework.results.observation import VALUE_COLUMN
from vivarium.framework.results.observer import Observer
98 changes: 14 additions & 84 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from __future__ import annotations

import itertools
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.framework.event import Event
from vivarium.framework.results.context import ResultsContext
from vivarium.framework.results.observation import ConcatenatingObservation
from vivarium.framework.results.stratification import Stratification
from vivarium.framework.values import Pipeline
from vivarium.manager import Manager

if TYPE_CHECKING:
from vivarium.framework.engine import Builder


VALUE_COLUMN = "value"


class SourceType(Enum):
COLUMN = 0
VALUE = 1
Expand Down Expand Up @@ -89,48 +82,35 @@ def setup(self, builder: "Builder") -> None:
def on_post_setup(self, _: Event) -> None:
"""Initialize results with 0s DataFrame' for each measure and all stratifications"""
registered_stratifications = self._results_context.stratifications
registered_stratification_names = set(
stratification.name for stratification in registered_stratifications
)

missing_stratifications = {}
unused_stratifications = registered_stratification_names.copy()

used_stratifications = set()
for event_name in self._results_context.observations:
for (
_pop_filter,
all_requested_stratification_names,
event_requested_stratification_names,
), observations in self._results_context.observations[event_name].items():
if event_requested_stratification_names is not None:
used_stratifications |= set(event_requested_stratification_names)
for observation in observations:
measure = observation.name
if all_requested_stratification_names is not None:
df, unused_stratifications = self._initialize_stratified_results(
measure,
all_requested_stratification_names,
registered_stratifications,
registered_stratification_names,
missing_stratifications,
unused_stratifications,
)
else:
# Initialize a completely empty dataframe
df = pd.DataFrame()
self._raw_results[measure] = df
self._raw_results[measure] = observation.results_initializer(
event_requested_stratification_names, registered_stratifications
)

registered_stratification_names = set(
stratification.name for stratification in registered_stratifications
)
unused_stratifications = registered_stratification_names - used_stratifications
if unused_stratifications:
self.logger.info(
"The following stratifications are registered but not used by any "
f"observers: \n{sorted(list(unused_stratifications))}"
)
missing_stratifications = used_stratifications - registered_stratification_names
if missing_stratifications:
# Sort by observer/measure and then by missing stratifiction
sorted_missing = {
key: sorted(list(missing_stratifications[key]))
for key in sorted(missing_stratifications)
}
raise ValueError(
"The following observers are requested to be stratified by "
f"stratifications that are not registered: \n{sorted_missing}"
f"stratifications that are not registered: \n{sorted(list(missing_stratifications))}"
)

def on_time_step_prepare(self, event: Event) -> None:
Expand Down Expand Up @@ -330,56 +310,6 @@ def _get_stratifications(
# Makes sure measure identifiers have fields in the same relative order.
return tuple(sorted(stratifications))

@staticmethod
def _initialize_stratified_results(
measure: str,
all_requested_stratification_names: List[str],
registered_stratifications: List[Stratification],
registered_stratification_names: Set[str],
missing_stratifications: Dict[str, Set[str]],
unused_stratifications: Set[str],
) -> Tuple[pd.DataFrame, Set[str]]:
all_requested_stratification_names = set(all_requested_stratification_names)

# Batch missing stratifications
observer_missing_stratifications = all_requested_stratification_names.difference(
registered_stratification_names
)
if observer_missing_stratifications:
missing_stratifications[measure] = observer_missing_stratifications

# Remove stratifications from the running list of unused stratifications
unused_stratifications = unused_stratifications.difference(
all_requested_stratification_names
)

# Set up the complete index of all used stratifications
requested_and_registered_stratifications = [
stratification
for stratification in registered_stratifications
if stratification.name in all_requested_stratification_names
]
stratification_values = {
stratification.name: stratification.categories
for stratification in requested_and_registered_stratifications
}
if stratification_values:
stratification_names = list(stratification_values.keys())
df = pd.DataFrame(
list(itertools.product(*stratification_values.values())),
columns=stratification_names,
).astype(CategoricalDtype)
else:
# We are aggregating the entire population so create a single-row index
stratification_names = ["stratification"]
df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype)

# Initialize a zeros dataframe
df[VALUE_COLUMN] = 0.0
df = df.set_index(stratification_names)

return df, unused_stratifications

def _add_resources(self, target: List[str], target_type: SourceType) -> None:
if len(target) == 0:
return # do nothing on empty lists
Expand Down
52 changes: 52 additions & 0 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import pandas as pd
from pandas.api.types import CategoricalDtype
from pandas.core.groupby import DataFrameGroupBy

from vivarium.framework.results.stratification import Stratification

VALUE_COLUMN = "value"


@dataclass
class BaseObservation(ABC):
Expand All @@ -15,6 +21,7 @@ class BaseObservation(ABC):
- `name`: name of the observation and is the measure it is observing
- `pop_filter`: a filter that is applied to the population before the observation is made
- `when`: the phase that the observation is registered to
- `results_initializer`: method that initializes the results
- `results_gatherer`: method that gathers the new observation results
- `results_updater`: method that updates the results with new observations
- `results_formatter`: method that formats the results
Expand All @@ -23,6 +30,7 @@ class BaseObservation(ABC):
name: str
pop_filter: str
when: str
results_initializer: Callable[..., pd.DataFrame]
results_gatherer: Callable[..., pd.DataFrame]
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame]
Expand Down Expand Up @@ -53,12 +61,21 @@ def __init__(
name=name,
pop_filter=pop_filter,
when=when,
results_initializer=self.initialize_results,
results_gatherer=results_gatherer,
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=None,
)

@staticmethod
def initialize_results(
requested_stratification_names: set[str],
registered_stratifications: List[Stratification],
) -> pd.DataFrame:
"""Initialize an empty dataframe."""
return pd.DataFrame()


class StratifiedObservation(BaseObservation):
"""Container class for managing stratified observations.
Expand Down Expand Up @@ -88,6 +105,7 @@ def __init__(
name=name,
pop_filter=pop_filter,
when=when,
results_initializer=self.initialize_results,
results_gatherer=self.gather_results,
results_updater=results_updater,
results_formatter=results_formatter,
Expand All @@ -96,6 +114,40 @@ def __init__(
self.aggregator_sources = aggregator_sources
self.aggregator = aggregator

@staticmethod
def initialize_results(
requested_stratification_names: set[str],
registered_stratifications: List[Stratification],
) -> pd.DataFrame:
"""Initialize a dataframe of 0s with complete set of stratifications as the index."""

# Set up the complete index of all used stratifications
requested_and_registered_stratifications = [
stratification
for stratification in registered_stratifications
if stratification.name in requested_stratification_names
]
stratification_values = {
stratification.name: stratification.categories
for stratification in requested_and_registered_stratifications
}
if stratification_values:
stratification_names = list(stratification_values.keys())
df = pd.DataFrame(
list(itertools.product(*stratification_values.values())),
columns=stratification_names,
).astype(CategoricalDtype)
else:
# We are aggregating the entire population so create a single-row index
stratification_names = ["stratification"]
df = pd.DataFrame(["all"], columns=stratification_names).astype(CategoricalDtype)

# Initialize a zeros dataframe
df[VALUE_COLUMN] = 0.0
df = df.set_index(stratification_names)

return df

def gather_results(
self,
pop_groups: DataFrameGroupBy,
Expand Down
21 changes: 13 additions & 8 deletions tests/framework/results/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,7 @@ def test_observers_with_missing_stratifications_fail():
"""
components = [QuidditchWinsObserver(), HousePointsObserver(), Hogwarts()]

expected_missing = { # NOTE: keep in alphabetical order
"house_points": ["power_level_group", "student_house"],
"quidditch_wins": ["familiar"],
}
expected_missing = ["familiar", "power_level_group", "student_house"]
expected_log_msg = re.escape(
"The following observers are requested to be stratified by stratifications "
f"that are not registered: \n{expected_missing}"
Expand All @@ -432,10 +429,18 @@ def test_unused_stratifications_are_logged(caplog):
but never actually used by an Observer
The HogwartsResultsStratifier registers "student_house", "familiar", and
"power_level" stratifiers. However, we will only use the HousePointsObserver
component which only requests to be stratified by "student_house" and "power_level"
"power_level_group" stratifiers. However, we will only use the QuidditchWinsObserver
which only uses "familiar" and the MagicalAttributesObserver which only uses
"power_level_group". We would thus expect only "student_house" to be logged
as an unused stratification.
"""
components = [HousePointsObserver(), Hogwarts(), HogwartsResultsStratifier()]
components = [
Hogwarts(),
HogwartsResultsStratifier(),
QuidditchWinsObserver(),
MagicalAttributesObserver(),
]
InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components)

log_split = caplog.text.split(
Expand All @@ -444,7 +449,7 @@ def test_unused_stratifications_are_logged(caplog):
# Check that the log message is present and only exists one time
assert len(log_split) == 2
# Check that the log message contains the expected Stratifications
assert "['familiar']" in log_split[1]
assert "['student_house']" in log_split[1]


def test_stratified_observation_results():
Expand Down

0 comments on commit 98dbcd4

Please # to comment.