Skip to content

Commit

Permalink
Feature/sbachmei/mic 5163 exclude unwanted results (#460)
Browse files Browse the repository at this point in the history
* implement exclusions
* handle name collisions when stratifying
* allow component_type to be any sequence
* Add tests for stratification registration through interface
  • Loading branch information
stevebachmeier committed Aug 13, 2024
1 parent c211cc7 commit 0bd9a2e
Show file tree
Hide file tree
Showing 13 changed files with 671 additions and 278 deletions.
1 change: 1 addition & 0 deletions docs/source/tutorials/exploration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ configuration by simply printing it.
sim = get_disease_model_simulation()

del sim.configuration['input_data']
del sim.configuration['stratification']['excluded_categories']

.. testcode:: configuration

Expand Down
12 changes: 6 additions & 6 deletions src/vivarium/framework/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
"""

import inspect
import typing
from typing import Any, Dict, Iterator, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Sequence, Tuple, Union

from layered_config_tree import (
ConfigurationError,
Expand All @@ -31,7 +30,7 @@
from vivarium.framework.lifecycle import LifeCycleManager
from vivarium.manager import Manager

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from vivarium.framework.engine import Builder


Expand Down Expand Up @@ -181,7 +180,7 @@ def add_components(self, components: Union[List[Component], Tuple[Component]]) -
self._components.add(c)

def get_components_by_type(
self, component_type: Union[type, Tuple[type, ...]]
self, component_type: Union[type, Sequence[type]]
) -> List[Component]:
"""Get all components that are an instance of ``component_type``.
Expand All @@ -196,7 +195,8 @@ def get_components_by_type(
A list of components of type ``component_type``.
"""
return [c for c in self._components if isinstance(c, component_type)]
# Convert component_type to a tuple for isinstance
return [c for c in self._components if isinstance(c, tuple(component_type))]

def get_component(self, name: str) -> Component:
"""Get the component with name ``name``.
Expand Down Expand Up @@ -348,7 +348,7 @@ def get_component(self, name: str) -> Component:
return self._manager.get_component(name)

def get_components_by_type(
self, component_type: Union[type, Tuple[type, ...]]
self, component_type: Union[type, Sequence[type]]
) -> List[Component]:
"""Get all components that are an instance of ``component_type``.
Expand Down
137 changes: 112 additions & 25 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from vivarium.framework.event import Event
from vivarium.framework.results.exceptions import ResultsConfigurationError
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.results.stratification import Stratification
from vivarium.framework.results.stratification import (
Stratification,
get_mapped_col_name,
get_original_col_name,
)


class ResultsContext:
Expand All @@ -25,6 +29,7 @@ class ResultsContext:
def __init__(self) -> None:
self.default_stratifications: List[str] = []
self.stratifications: List[Stratification] = []
self.excluded_categories: dict[str, list[str]] = {}
# keys are event names: [
# "time_step__prepare",
# "time_step",
Expand All @@ -42,6 +47,9 @@ def name(self) -> str:

def setup(self, builder: Builder) -> None:
self.logger = builder.logging.get_logger(self.name)
self.excluded_categories = (
builder.configuration.stratification.excluded_categories.to_dict()
)

# noinspection PyAttributeOutsideInit
def set_default_stratifications(self, default_grouping_columns: List[str]) -> None:
Expand All @@ -57,6 +65,7 @@ def add_stratification(
name: str,
sources: List[str],
categories: List[str],
excluded_categories: Optional[List[str]],
mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]],
is_vectorized: bool,
) -> None:
Expand All @@ -71,6 +80,9 @@ def add_stratification(
categorization.
categories
List of string values that the mapper is allowed to output.
excluded_categories
List of mapped string values to be excluded from results processing.
If None (the default), will use exclusions as defined in the configuration.
mapper
A callable that emits values in `categories` given inputs from columns
and values in the `requires_columns` and `requires_values`, respectively.
Expand Down Expand Up @@ -100,7 +112,35 @@ def add_stratification(
raise ValueError(
f"Found duplicate categories in stratification '{name}': {categories}."
)
stratification = Stratification(name, sources, categories, mapper, is_vectorized)

# Handle excluded categories. If excluded_categories are explicitly
# passed in, we use that instead of what is in the model spec.
to_exclude = (
excluded_categories
if excluded_categories is not None
else self.excluded_categories.get(name, [])
)
unknown_exclusions = set(to_exclude) - set(categories)
if len(unknown_exclusions) > 0:
raise ValueError(
f"Excluded categories {unknown_exclusions} not found in categories "
f"{categories} for stratification '{name}'."
)
if to_exclude:
self.logger.debug(
f"'{name}' has category exclusion requests: {to_exclude}\n"
"Removing these from the allowable categories."
)
categories = [category for category in categories if category not in to_exclude]

stratification = Stratification(
name,
sources,
categories,
to_exclude,
mapper,
is_vectorized,
)
self.stratifications.append(stratification)

def register_observation(
Expand Down Expand Up @@ -166,44 +206,91 @@ def gather_results(
None,
None,
]:
# Optimization: We store all the producers by pop_filter and stratifications
# so that we only have to apply them once each time we compute results.
"""Generate current results for all observations at this lifecycle phase and event."""

for stratification in self.stratifications:
population = stratification(population)
# Add new columns of mapped values to the population to prevent name collisions
new_column = get_mapped_col_name(stratification.name)
if new_column in population.columns:
raise ValueError(
f"Stratification column '{new_column}' "
"already exists in the state table or as a pipeline which is a required "
"name for stratifying results - choose a different name."
)
population[new_column] = stratification(population)

for (pop_filter, stratifications), observations in self.observations[
# Optimization: We store all the producers by pop_filter and stratifications
# so that we only have to apply them once each time we compute results.
for (pop_filter, stratification_names), observations in self.observations[
lifecycle_phase
].items():
# Results production can be simplified to
# filter -> groupby -> aggregate in all situations we've seen.
filtered_pop = self._filter_population(population, pop_filter)
filtered_pop = self._filter_population(
population, pop_filter, stratification_names
)
if filtered_pop.empty:
yield None, None, None
else:
if stratifications is None:
if stratification_names is None:
pop = filtered_pop
else:
pop = self._get_groups(stratifications, filtered_pop)
pop = self._get_groups(stratification_names, filtered_pop)
for observation in observations:
yield observation.observe(
event, pop, stratifications
), observation.name, observation.results_updater
results = observation.observe(event, pop, stratification_names)
if results is not None:
self._rename_stratification_columns(results)

@staticmethod
def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame:
return population.query(pop_filter) if pop_filter else population
yield (results, observation.name, observation.results_updater)

def _filter_population(
    self,
    population: pd.DataFrame,
    pop_filter: str,
    stratification_names: Optional[tuple[str, ...]],
) -> pd.DataFrame:
    """Apply the population filter and drop excluded stratification rows.

    Parameters
    ----------
    population
        The population table, including the mapped stratification columns.
    pop_filter
        A ``DataFrame.query`` expression. If empty, no query is applied and
        a copy of ``population`` is used instead.
    stratification_names
        Names of the stratifications requested for the observations being
        processed. May be None or empty, in which case no rows are dropped.

    Returns
    -------
        The filtered population.
    """
    if pop_filter:
        filtered = population.query(pop_filter)
    else:
        filtered = population.copy()

    if not stratification_names:
        return filtered

    # Rows holding NaN in a mapped stratification column are removed; these
    # NaNs are expected to appear only where the mapper returned an excluded
    # category.
    mapped_columns = [get_mapped_col_name(name) for name in stratification_names]
    return filtered.dropna(subset=mapped_columns)

@staticmethod
def _get_groups(
stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy:
# NOTE: It's a bit hacky how we are handling the groupby object if there
# are no stratifications. The alternative is to use the entire population
# instead of a groupby object, but then we would need to handle
# the different ways the aggregator can behave.

return (
filtered_pop.groupby(list(stratifications), observed=False)
if list(stratifications)
else filtered_pop.groupby(lambda _: "all")
)
"""Group the population by stratifications.
NOTE: Stratifications at this point can be an empty tuple.
HACK: It's a bit hacky how we are handling the groupby object if there
are no stratifications. The alternative is to use the entire population
instead of a groupby object, but then we would need to handle
the different ways the aggregator can behave.
"""

if stratifications:
pop_groups = filtered_pop.groupby(
[get_mapped_col_name(stratification) for stratification in stratifications],
observed=False,
)
else:
pop_groups = filtered_pop.groupby(lambda _: "all")
return pop_groups

def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
    """Convert mapped stratification index names back to the original names.

    The index of ``results`` is renamed in place; nothing is returned.

    Parameters
    ----------
    results
        Observation results whose index level names (for a ``MultiIndex``)
        or index name (for a flat index) are passed through
        ``get_original_col_name``. An unnamed flat index is left untouched.
    """
    if isinstance(results.index, pd.MultiIndex):
        original_names = [get_original_col_name(level) for level in results.index.names]
        results.rename_axis(index=original_names, inplace=True)
        return

    mapped_name = results.index.name
    if mapped_name is None:
        return
    results.index.rename(get_original_col_name(mapped_name), inplace=True)
17 changes: 16 additions & 1 deletion src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def register_stratification(
self,
name: str,
categories: List[str],
excluded_categories: Optional[List[str]] = None,
mapper: Optional[Callable[[pd.DataFrame], pd.Series[str]]] = None,
is_vectorized: bool = False,
requires_columns: List[str] = [],
Expand All @@ -86,6 +87,9 @@ def register_stratification(
Name of the column created by the stratification.
categories
List of string values that the mapper is allowed to output.
excluded_categories
List of mapped string values to be excluded from results processing.
If None (the default), will use exclusions as defined in the configuration.
mapper
A callable that emits values in `categories` given inputs from columns
and values in the `requires_columns` and `requires_values`, respectively.
Expand All @@ -107,6 +111,7 @@ def register_stratification(
self._manager.register_stratification(
name,
categories,
excluded_categories,
mapper,
is_vectorized,
requires_columns,
Expand All @@ -119,6 +124,7 @@ def register_binned_stratification(
binned_column: str,
bin_edges: List[Union[int, float]] = [],
labels: List[str] = [],
excluded_categories: Optional[List[str]] = None,
target_type: str = "column",
**cut_kwargs: Dict,
) -> None:
Expand All @@ -136,6 +142,9 @@ def register_binned_stratification(
labels
List of string labels for bins. The length must be equal to the length
of `bin_edges` minus 1.
excluded_categories
List of mapped string values to be excluded from results processing.
If None (the default), will use exclusions as defined in the configuration.
target_type
"column" or "value"
**cut_kwargs
Expand All @@ -146,7 +155,13 @@ def register_binned_stratification(
None
"""
self._manager.register_binned_stratification(
target, target_type, binned_column, bin_edges, labels, **cut_kwargs
target,
binned_column,
bin_edges,
labels,
excluded_categories,
target_type,
**cut_kwargs,
)

###############################
Expand Down
Loading

0 comments on commit 0bd9a2e

Please sign in to comment.