From 0bd9a2e96b01c1e5061191e1b8a5bf301feebe7b Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:10:45 -0600 Subject: [PATCH] Feature/sbachmei/mic 5163 exclude unwanted results (#460) * implement exclusions * handle name collisions when stratifying * allow component_type to be any sequence * Add tests for stratification registration through interface --- docs/source/tutorials/exploration.rst | 1 + src/vivarium/framework/components/manager.py | 12 +- src/vivarium/framework/results/context.py | 137 ++++++++-- src/vivarium/framework/results/interface.py | 17 +- src/vivarium/framework/results/manager.py | 22 +- src/vivarium/framework/results/observation.py | 8 +- .../framework/results/stratification.py | 64 +++-- tests/framework/results/helpers.py | 23 +- tests/framework/results/test_context.py | 233 ++++++++++++------ tests/framework/results/test_interface.py | 122 ++++++++- tests/framework/results/test_manager.py | 120 +++++---- tests/framework/results/test_observation.py | 21 +- .../framework/results/test_stratification.py | 169 +++++++------ 13 files changed, 671 insertions(+), 278 deletions(-) diff --git a/docs/source/tutorials/exploration.rst b/docs/source/tutorials/exploration.rst index bce3cf917..ae72caa63 100644 --- a/docs/source/tutorials/exploration.rst +++ b/docs/source/tutorials/exploration.rst @@ -95,6 +95,7 @@ configuration by simply printing it. sim = get_disease_model_simulation() del sim.configuration['input_data'] + del sim.configuration['stratification']['excluded_categories'] .. testcode:: configuration diff --git a/src/vivarium/framework/components/manager.py b/src/vivarium/framework/components/manager.py index b4968951f..3af646e82 100644 --- a/src/vivarium/framework/components/manager.py +++ b/src/vivarium/framework/components/manager.py @@ -17,8 +17,7 @@ """ import inspect -import typing -from typing import Any, Dict, Iterator, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Sequence, Tuple, Union from layered_config_tree import ( ConfigurationError, @@ -31,7 +30,7 @@ from vivarium.framework.lifecycle import LifeCycleManager from vivarium.manager import Manager -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from vivarium.framework.engine import Builder @@ -181,7 +180,7 @@ def add_components(self, components: Union[List[Component], Tuple[Component]]) - self._components.add(c) def get_components_by_type( - self, component_type: Union[type, Tuple[type, ...]] + self, component_type: Union[type, Sequence[type]] ) -> List[Component]: """Get all components that are an instance of ``component_type``. @@ -196,7 +195,8 @@ def get_components_by_type( A list of components of type ``component_type``. """ - return [c for c in self._components if isinstance(c, component_type)] + # Convert component_type to a tuple for isinstance + return [c for c in self._components if isinstance(c, tuple(component_type))] def get_component(self, name: str) -> Component: """Get the component with name ``name``. @@ -348,7 +348,7 @@ def get_component(self, name: str) -> Component: return self._manager.get_component(name) def get_components_by_type( - self, component_type: Union[type, Tuple[type, ...]] + self, component_type: Union[type, Sequence[type]] ) -> List[Component]: """Get all components that are an instance of ``component_type``. diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 5128c25a1..99562fabf 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -10,7 +10,11 @@ from vivarium.framework.event import Event from vivarium.framework.results.exceptions import ResultsConfigurationError from vivarium.framework.results.observation import BaseObservation -from vivarium.framework.results.stratification import Stratification +from vivarium.framework.results.stratification import ( + Stratification, + get_mapped_col_name, + get_original_col_name, +) class ResultsContext: @@ -25,6 +29,7 @@ class ResultsContext: def __init__(self) -> None: self.default_stratifications: List[str] = [] self.stratifications: List[Stratification] = [] + self.excluded_categories: dict[str, list[str]] = {} # keys are event names: [ # "time_step__prepare", # "time_step", @@ -42,6 +47,9 @@ def name(self) -> str: def setup(self, builder: Builder) -> None: self.logger = builder.logging.get_logger(self.name) + self.excluded_categories = ( + builder.configuration.stratification.excluded_categories.to_dict() + ) # noinspection PyAttributeOutsideInit def set_default_stratifications(self, default_grouping_columns: List[str]) -> None: @@ -57,6 +65,7 @@ def add_stratification( name: str, sources: List[str], categories: List[str], + excluded_categories: Optional[List[str]], mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]], is_vectorized: bool, ) -> None: @@ -71,6 +80,9 @@ def add_stratification( categorization. categories List of string values that the mapper is allowed to output. + excluded_categories + List of mapped string values to be excluded from results processing. + If None (the default), will use exclusions as defined in the configuration. mapper A callable that emits values in `categories` given inputs from columns and values in the `requires_columns` and `requires_values`, respectively. @@ -100,7 +112,35 @@ def add_stratification( raise ValueError( f"Found duplicate categories in stratification '{name}': {categories}." ) - stratification = Stratification(name, sources, categories, mapper, is_vectorized) + + # Handle excluded categories. If excluded_categories are explicitly + # passed in, we use that instead of what is in the model spec. + to_exclude = ( + excluded_categories + if excluded_categories is not None + else self.excluded_categories.get(name, []) + ) + unknown_exclusions = set(to_exclude) - set(categories) + if len(unknown_exclusions) > 0: + raise ValueError( + f"Excluded categories {unknown_exclusions} not found in categories " + f"{categories} for stratification '{name}'." + ) + if to_exclude: + self.logger.debug( + f"'{name}' has category exclusion requests: {to_exclude}\n" + "Removing these from the allowable categories." + ) + categories = [category for category in categories if category not in to_exclude] + + stratification = Stratification( + name, + sources, + categories, + to_exclude, + mapper, + is_vectorized, + ) self.stratifications.append(stratification) def register_observation( @@ -166,44 +206,91 @@ def gather_results( None, None, ]: - # Optimization: We store all the producers by pop_filter and stratifications - # so that we only have to apply them once each time we compute results. + """Generate current results for all observations at this lifecycle phase and event.""" + for stratification in self.stratifications: - population = stratification(population) + # Add new columns of mapped values to the population to prevent name collisions + new_column = get_mapped_col_name(stratification.name) + if new_column in population.columns: + raise ValueError( + f"Stratification column '{new_column}' " + "already exists in the state table or as a pipeline which is a required " + "name for stratifying results - choose a different name." + ) + population[new_column] = stratification(population) - for (pop_filter, stratifications), observations in self.observations[ + # Optimization: We store all the producers by pop_filter and stratifications + # so that we only have to apply them once each time we compute results. + for (pop_filter, stratification_names), observations in self.observations[ lifecycle_phase ].items(): # Results production can be simplified to # filter -> groupby -> aggregate in all situations we've seen. - filtered_pop = self._filter_population(population, pop_filter) + filtered_pop = self._filter_population( + population, pop_filter, stratification_names + ) if filtered_pop.empty: yield None, None, None else: - if stratifications is None: + if stratification_names is None: pop = filtered_pop else: - pop = self._get_groups(stratifications, filtered_pop) + pop = self._get_groups(stratification_names, filtered_pop) for observation in observations: - yield observation.observe( - event, pop, stratifications - ), observation.name, observation.results_updater + results = observation.observe(event, pop, stratification_names) + if results is not None: + self._rename_stratification_columns(results) - @staticmethod - def _filter_population(population: pd.DataFrame, pop_filter: str) -> pd.DataFrame: - return population.query(pop_filter) if pop_filter else population + yield (results, observation.name, observation.results_updater) + + def _filter_population( + self, + population: pd.DataFrame, + pop_filter: str, + stratification_names: Optional[tuple[str, ...]], + ) -> pd.DataFrame: + """Filter the population based on the filter string as well as any + excluded stratification categories + """ + pop = population.query(pop_filter) if pop_filter else population.copy() + if stratification_names: + # Drop all rows in the mapped_stratification columns that have NaN values + # (which only exist if the mapper returned an excluded category). + pop = pop.dropna( + subset=[ + get_mapped_col_name(stratification) + for stratification in stratification_names + ] + ) + return pop @staticmethod def _get_groups( stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame ) -> DataFrameGroupBy: - # NOTE: It's a bit hacky how we are handling the groupby object if there - # are no stratifications. The alternative is to use the entire population - # instead of a groupby object, but then we would need to handle - # the different ways the aggregator can behave. - - return ( - filtered_pop.groupby(list(stratifications), observed=False) - if list(stratifications) - else filtered_pop.groupby(lambda _: "all") - ) + """Group the population by stratifications. + NOTE: Stratifications at this point can be an empty tuple. + HACK: It's a bit hacky how we are handling the groupby object if there + are no stratifications. The alternative is to use the entire population + instead of a groupby object, but then we would need to handle + the different ways the aggregator can behave. + """ + + if stratifications: + pop_groups = filtered_pop.groupby( + [get_mapped_col_name(stratification) for stratification in stratifications], + observed=False, + ) + else: + pop_groups = filtered_pop.groupby(lambda _: "all") + return pop_groups + + def _rename_stratification_columns(self, results: pd.DataFrame) -> None: + """convert stratified mapped index names to original""" + if isinstance(results.index, pd.MultiIndex): + idx_names = [get_original_col_name(name) for name in results.index.names] + results.rename_axis(index=idx_names, inplace=True) + else: + idx_name = results.index.name + if idx_name is not None: + results.index.rename(get_original_col_name(idx_name), inplace=True) diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index e251271cf..228bcdefc 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -73,6 +73,7 @@ def register_stratification( self, name: str, categories: List[str], + excluded_categories: Optional[List[str]] = None, mapper: Optional[Callable[[pd.DataFrame], pd.Series[str]]] = None, is_vectorized: bool = False, requires_columns: List[str] = [], @@ -86,6 +87,9 @@ def register_stratification( Name of the of the column created by the stratification. categories List of string values that the mapper is allowed to output. + excluded_categories + List of mapped string values to be excluded from results processing. + If None (the default), will use exclusions as defined in the configuration. mapper A callable that emits values in `categories` given inputs from columns and values in the `requires_columns` and `requires_values`, respectively. @@ -107,6 +111,7 @@ def register_stratification( self._manager.register_stratification( name, categories, + excluded_categories, mapper, is_vectorized, requires_columns, @@ -119,6 +124,7 @@ def register_binned_stratification( binned_column: str, bin_edges: List[Union[int, float]] = [], labels: List[str] = [], + excluded_categories: Optional[List[str]] = None, target_type: str = "column", **cut_kwargs: Dict, ) -> None: @@ -136,6 +142,9 @@ def register_binned_stratification( labels List of string labels for bins. The length must be equal to the length of `bin_edges` minus 1. + excluded_categories + List of mapped string values to be excluded from results processing. + If None (the default), will use exclusions as defined in the configuration. target_type "column" or "value" **cut_kwargs @@ -146,7 +155,13 @@ def register_binned_stratification( None """ self._manager.register_binned_stratification( - target, target_type, binned_column, bin_edges, labels, **cut_kwargs + target, + binned_column, + bin_edges, + labels, + excluded_categories, + target_type, + **cut_kwargs, ) ############################### diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index b07d04106..acc012c5b 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -35,6 +35,7 @@ class ResultsManager(Manager): CONFIGURATION_DEFAULTS = { "stratification": { "default": [], + "excluded_categories": {}, } } @@ -71,6 +72,8 @@ def get_results(self) -> Dict[str, pd.DataFrame]: # noinspection PyAttributeOutsideInit def setup(self, builder: "Builder") -> None: + self._results_context.setup(builder) + self.logger = builder.logging.get_logger(self.name) self.population_view = builder.population.get_view([]) self.clock = builder.time.clock() @@ -138,6 +141,8 @@ def gather_results(self, lifecycle_phase: str, event: Event) -> None: with 0.0. """ population = self._prepare_population(event) + if population.empty: + return for results_group, measure, updater in self._results_context.gather_results( population, lifecycle_phase, event ): @@ -158,6 +163,7 @@ def register_stratification( self, name: str, categories: List[str], + excluded_categories: Optional[List[str]], mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]], is_vectorized: bool, requires_columns: List[str] = [], @@ -171,6 +177,9 @@ def register_stratification( Name of the of the column created by the stratification. categories List of string values that the mapper is allowed to output. + excluded_categories + List of mapped string values to be excluded from results processing. + If None (the default), will use exclusions as defined in the configuration. mapper A callable that emits values in `categories` given inputs from columns and values in the `requires_columns` and `requires_values`, respectively. @@ -192,7 +201,7 @@ def register_stratification( self.logger.debug(f"Registering stratification {name}") target_columns = list(requires_columns) + list(requires_values) self._results_context.add_stratification( - name, target_columns, categories, mapper, is_vectorized + name, target_columns, categories, excluded_categories, mapper, is_vectorized ) self._add_resources(requires_columns, SourceType.COLUMN) self._add_resources(requires_values, SourceType.VALUE) @@ -200,10 +209,11 @@ def register_stratification( def register_binned_stratification( self, target: str, - target_type: str, binned_column: str, bin_edges: List[Union[int, float]], labels: List[str], + excluded_categories: Optional[List[str]], + target_type: str, **cut_kwargs, ) -> None: """Manager-level registration of a continuous `target` quantity to observe into bins in a `binned_column`. @@ -212,8 +222,6 @@ def register_binned_stratification( ---------- target String name of the state table column or value pipeline used to stratify. - target_type - "column" or "value" binned_column String name of the column for the binned quantities. bin_edges @@ -224,6 +232,11 @@ def register_binned_stratification( labels List of string labels for bins. The length must equal to the length of `bin_edges` minus one. + excluded_categories + List of mapped string values to be excluded from results processing. + If None (the default), will use exclusions as defined in the configuration. + target_type + "column" or "value" **cut_kwargs Keyword arguments for :meth: pandas.cut. @@ -252,6 +265,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: self.register_stratification( name=binned_column, categories=labels, + excluded_categories=excluded_categories, mapper=_bin_data, is_vectorized=True, **target_kwargs, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f43543d5c..915d73290 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -140,7 +140,7 @@ def __init__( @staticmethod def initialize_results( requested_stratification_names: set[str], - registered_stratifications: List[Stratification], + registered_stratifications: list[Stratification], ) -> pd.DataFrame: """Initialize a dataframe of 0s with complete set of stratifications as the index.""" @@ -188,7 +188,7 @@ def results_gatherer( @staticmethod def _aggregate( pop_groups: DataFrameGroupBy, - aggregator_sources: Optional[List[str]], + aggregator_sources: Optional[list[str]], aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], ) -> Union[pd.Series[float], pd.DataFrame]: aggregates = ( @@ -236,7 +236,7 @@ def __init__( when: str, results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], stratifications: Tuple[str, ...], - aggregator_sources: Optional[List[str]], + aggregator_sources: Optional[list[str]], aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], to_observe: Callable[[Event], bool] = lambda event: True, ): @@ -286,7 +286,7 @@ def __init__( name: str, pop_filter: str, when: str, - included_columns: List[str], + included_columns: list[str], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], to_observe: Callable[[Event], bool] = lambda event: True, ): diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 3935aa6c7..ba8452763 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -6,6 +6,8 @@ import pandas as pd from pandas.api.types import CategoricalDtype +STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" + @dataclass class Stratification: @@ -14,14 +16,15 @@ class Stratification: Each Stratification represents a set of mutually exclusive and collectively exhaustive categories into which simulants can be assigned. - The `Stratification` class has five fields: `name`, `sources`, `mapper`, - `categories`, and `is_vectorized`. The `name` is the name of the column - created by the mapper. The `sources` is a list of columns in the extended - state table that are the inputs to the mapper function. Simulants will - later be grouped by this column (or these columns) during stratification. - `categories` is a set of values that the mapper is allowed to output. The - `mapper` is the method that transforms the source to the name column. - The method produces an output column by calling the mapper on the source + The `Stratification` class has six fields: `name`, `sources`, `mapper`, + `categories`, `excluded_categories`, and `is_vectorized`. The `name` is the + name of the column created by the mapper. The `sources` is a list of columns + in the extended state table that are the inputs to the mapper function. Simulants + will later be grouped by this column (or these columns) during stratification. + `categories` is the total set of values that the mapper can output. + `excluded_categories` are values that have been requested to be excluded (and + already removed) from `categories`. The `mapper` is the method that transforms the source + to the name column. The method produces an output column by calling the mapper on the source columns. If the mapper is `None`, the default identity mapper is used. If the mapper is not vectorized this is performed by using `pd.apply`. Finally, `is_vectorized` is a boolean parameter that signifies whether @@ -35,6 +38,7 @@ class Stratification: name: str sources: List[str] categories: List[str] + excluded_categories: List[str] mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]] = None is_vectorized: bool = False @@ -58,21 +62,47 @@ def __post_init__(self) -> None: if not self.sources: raise ValueError("The sources argument must be non-empty.") - def __call__(self, population: pd.DataFrame) -> pd.DataFrame: + def __call__(self, population: pd.DataFrame) -> pd.Series[str]: + """Apply the mapper to the population 'sources' columns and add the result + to the population. Any excluded categories (which have already been removed + from self.categories) will be converted to NaNs in the new column + and dropped later at the observation level. + """ if self.is_vectorized: - raw_mapped_column = self.mapper(population[self.sources]) + mapped_column = self.mapper(population[self.sources]) else: - raw_mapped_column = population[self.sources].apply(self.mapper, axis=1) - mapped_column = raw_mapped_column.astype( - CategoricalDtype(categories=self.categories, ordered=True) + mapped_column = population[self.sources].apply(self.mapper, axis=1) + unknown_categories = set(mapped_column) - set( + self.categories + self.excluded_categories ) + # Reduce all nans to a single one + unknown_categories = [cat for cat in unknown_categories if not pd.isna(cat)] if mapped_column.isna().any(): - invalid_categories = set(raw_mapped_column.unique()) - set(self.categories) - raise ValueError(f"Invalid values '{invalid_categories}' found in {self.name}.") + unknown_categories.append(mapped_column[mapped_column.isna()].iat[0]) + if unknown_categories: + raise ValueError(f"Invalid values mapped to {self.name}: {unknown_categories}") - population[self.name] = mapped_column - return population + # Convert the dtype to the allowed categories. Note that this will + # result in Nans for any values in excluded_categories. + mapped_column = mapped_column.astype( + CategoricalDtype(categories=self.categories, ordered=True) + ) + return mapped_column @staticmethod def _default_mapper(pop: pd.DataFrame) -> pd.Series[str]: return pop.squeeze(axis=1) + + +def get_mapped_col_name(col_name: str) -> str: + """Return a new column name to be used for mapped values""" + return f"{col_name}_{STRATIFICATION_COLUMN_SUFFIX}" + + +def get_original_col_name(col_name: str) -> str: + """Return the original column name given a modified mapped column name.""" + return ( + col_name[: -(len(STRATIFICATION_COLUMN_SUFFIX)) - 1] + if col_name.endswith(f"_{STRATIFICATION_COLUMN_SUFFIX}") + else col_name + ) diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index 37c5fd295..6ff72e54d 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -11,11 +11,11 @@ from vivarium.framework.results.observer import Observer NAME = "hogwarts_house" -SOURCES = ["first_name", "last_name"] -CATEGORIES = ["hufflepuff", "ravenclaw", "slytherin", "gryffindor"] +NAME_COLUMNS = ["first_name", "last_name"] +HOUSE_CATEGORIES = ["hufflepuff", "ravenclaw", "slytherin", "gryffindor"] STUDENT_TABLE = pd.DataFrame( np.array([["harry", "potter"], ["severus", "snape"], ["luna", "lovegood"]]), - columns=SOURCES, + columns=NAME_COLUMNS, ) STUDENT_HOUSES = pd.Series(["gryffindor", "slytherin", "ravenclaw"]) @@ -30,7 +30,7 @@ POWER_LEVEL_BIN_EDGES = [0, 25, 50, 75, 100] POWER_LEVEL_GROUP_LABELS = ["low", "medium", "high", "very high"] TRACKED_STATUSES = [True, False] -RECORDS = list(itertools.product(CATEGORIES, FAMILIARS, POWER_LEVELS, TRACKED_STATUSES)) +RECORDS = list(itertools.product(HOUSE_CATEGORIES, FAMILIARS, POWER_LEVELS, TRACKED_STATUSES)) BASE_POPULATION = pd.DataFrame(data=RECORDS, columns=COL_NAMES) HARRY_POTTER_CONFIG = { @@ -254,17 +254,18 @@ def update_valedictorian(self, _existing_df, new_df): class HogwartsResultsStratifier(Component): def setup(self, builder: Builder) -> None: builder.results.register_stratification( - "student_house", list(STUDENT_HOUSES), requires_columns=["student_house"] + name="student_house", + categories=list(STUDENT_HOUSES), + requires_columns=["student_house"], ) builder.results.register_stratification( - "familiar", FAMILIARS, requires_columns=["familiar"] + name="familiar", categories=FAMILIARS, requires_columns=["familiar"] ) builder.results.register_binned_stratification( "power_level", "power_level_group", POWER_LEVEL_BIN_EDGES, POWER_LEVEL_GROUP_LABELS, - "column", ) @@ -286,7 +287,7 @@ def results_formatter( return results[other_cols + [VALUE_COLUMN]].sort_index().reset_index() -def sorting_hat_vector(state_table: pd.DataFrame) -> pd.Series: +def sorting_hat_vectorized(state_table: pd.DataFrame) -> pd.Series: sorted_series = state_table.apply(sorting_hat_serial, axis=1) return sorted_series @@ -309,7 +310,7 @@ def sorting_hat_bad_mapping(simulant_row: pd.Series) -> str: def verify_stratification_added( - stratification_list, name, sources, categories, mapper, is_vectorized + stratification_list, name, sources, categories, excluded_categories, mapper, is_vectorized ): """Verify that a :class: `vivarium.framework.results.stratification.Stratification` is in `stratification_list`""" matching_stratification_found = False @@ -317,7 +318,9 @@ def verify_stratification_added( # big equality check if ( stratification.name == name - and sorted(stratification.categories) == sorted(categories) + and sorted(stratification.categories) + == sorted([cat for cat in categories if cat not in excluded_categories]) + and sorted(stratification.excluded_categories) == sorted(excluded_categories) and stratification.mapper == mapper and stratification.is_vectorized == is_vectorized and sorted(stratification.sources) == sorted(sources) diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 95bf3c9a8..34cb26cf3 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -3,18 +3,19 @@ import re from datetime import timedelta +import numpy as np import pandas as pd import pytest from pandas.core.groupby import DataFrameGroupBy from tests.framework.results.helpers import ( BASE_POPULATION, - CATEGORIES, FAMILIARS, + HOUSE_CATEGORIES, NAME, - SOURCES, + NAME_COLUMNS, sorting_hat_serial, - sorting_hat_vector, + sorting_hat_vectorized, verify_stratification_added, ) from vivarium.framework.event import Event @@ -35,86 +36,87 @@ def mocked_event(mocker) -> Event: @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized", + "mapper, is_vectorized", [ - (NAME, SOURCES, CATEGORIES, sorting_hat_vector, True), - (NAME, SOURCES, CATEGORIES, sorting_hat_serial, False), + (sorting_hat_vectorized, True), + (sorting_hat_serial, False), ], ids=["vectorized_mapper", "non-vectorized_mapper"], ) -def test_add_stratification(name, sources, categories, mapper, is_vectorized): +def test_add_stratification(mapper, is_vectorized, mocker): ctx = ResultsContext() + mocker.patch.object(ctx, "excluded_categories", {}) assert not verify_stratification_added( - ctx.stratifications, name, sources, categories, mapper, is_vectorized + ctx.stratifications, NAME, NAME_COLUMNS, HOUSE_CATEGORIES, [], mapper, is_vectorized + ) + ctx.add_stratification( + name=NAME, + sources=NAME_COLUMNS, + categories=HOUSE_CATEGORIES, + excluded_categories=None, + mapper=mapper, + is_vectorized=is_vectorized, ) - ctx.add_stratification(name, sources, categories, mapper, is_vectorized) assert verify_stratification_added( - ctx.stratifications, name, sources, categories, mapper, is_vectorized + ctx.stratifications, NAME, NAME_COLUMNS, HOUSE_CATEGORIES, [], mapper, is_vectorized ) @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized, expected_exception", + "name, categories, excluded_categories, msg_match", [ - ( # sources not in population columns + ( + "duplicate_name", + HOUSE_CATEGORIES, + [], + "Stratification name 'duplicate_name' is already used: ", + ), + ( NAME, - ["middle_initial"], - CATEGORIES, - sorting_hat_vector, - True, - TypeError, + HOUSE_CATEGORIES + ["slytherin"], + [], + f"Found duplicate categories in stratification '{NAME}': ['slytherin']", ), - ( # is_vectorized=True with non-vectorized mapper + ( NAME, - SOURCES, - CATEGORIES, - sorting_hat_serial, - True, - Exception, + HOUSE_CATEGORIES + ["gryffindor", "slytherin"], + [], + f"Found duplicate categories in stratification '{NAME}': ['gryffindor', 'slytherin']", ), - ( # is_vectorized=False with vectorized mapper + ( NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, - False, - Exception, + HOUSE_CATEGORIES, + ["gryfflepuff"], + "Excluded categories {'gryfflepuff'} not found in categories", ), ], -) -def test_add_stratification_raises( - name, sources, categories, mapper, is_vectorized, expected_exception -): - ctx = ResultsContext() - with pytest.raises(expected_exception): - raise ctx.add_stratification(name, sources, categories, mapper, is_vectorized) - - -def test_add_stratifcation_duplicate_name_raises(): - ctx = ResultsContext() - ctx.add_stratification(NAME, SOURCES, CATEGORIES, sorting_hat_vector, True) - with pytest.raises(ValueError, match=f"Stratification name '{NAME}' is already used: "): - # register a different stratification but w/ the same name - ctx.add_stratification(NAME, [], [], None, False) - - -@pytest.mark.parametrize( - "duplicates", - [ - ["slytherin"], - ["gryffindor", "slytherin"], + ids=[ + "duplicate_name", + "duplicate_category", + "duplicate_categories", + "unknown_excluded_category", ], ) -def test_add_stratification_duplicate_category_raises(duplicates): +def test_add_stratification_raises(name, categories, excluded_categories, msg_match, mocker): ctx = ResultsContext() - with pytest.raises( - ValueError, - match=re.escape( - f"Found duplicate categories in stratification '{NAME}': {duplicates}" - ), - ): + mocker.patch.object(ctx, "excluded_categories", {name: excluded_categories}) + # Register a stratification to test against duplicate stratifications + ctx.add_stratification( + name="duplicate_name", + sources=["foo"], + categories=["bar"], + excluded_categories=None, + mapper=sorting_hat_serial, + is_vectorized=False, + ) + with pytest.raises(ValueError, match=re.escape(msg_match)): ctx.add_stratification( - NAME, SOURCES, CATEGORIES + duplicates, sorting_hat_vector, True + name=name, + sources=NAME_COLUMNS, + categories=categories, + excluded_categories=excluded_categories, + mapper=sorting_hat_vectorized, + is_vectorized=True, ) @@ -225,9 +227,23 @@ def test_adding_observation_gather_results( # Set up stratifications if "house" in stratifications: - ctx.add_stratification("house", ["house"], CATEGORIES, None, True) + ctx.add_stratification( + name="house", + sources=["house"], + categories=HOUSE_CATEGORIES, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) if "familiar" in stratifications: - ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) + ctx.add_stratification( + name="familiar", + sources=["familiar"], + categories=FAMILIARS, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) ctx.register_observation( observation_type=AddingObservation, name="foo", @@ -355,9 +371,23 @@ def test_gather_results_partial_stratifications_in_results( # Set up stratifications if "house" in stratifications: - ctx.add_stratification("house", ["house"], CATEGORIES, None, True) + ctx.add_stratification( + name="house", + sources=["house"], + categories=HOUSE_CATEGORIES, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) if "familiar" in stratifications: - ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) + ctx.add_stratification( + name="familiar", + sources=["familiar"], + categories=FAMILIARS, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) ctx.register_observation( observation_type=AddingObservation, @@ -448,8 +478,22 @@ def test_bad_aggregator_stratification(mocked_event): lifecycle_phase = "collect_metrics" # Set up stratifications - ctx.add_stratification("house", ["house"], CATEGORIES, None, True) - ctx.add_stratification("familiar", ["familiar"], FAMILIARS, None, True) + ctx.add_stratification( + name="house", + sources=["house"], + categories=HOUSE_CATEGORIES, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) + ctx.add_stratification( + name="familiar", + sources=["familiar"], + categories=FAMILIARS, + excluded_categories=None, + mapper=None, + is_vectorized=True, + ) ctx.register_observation( observation_type=AddingObservation, name="this_shouldnt_work", @@ -469,41 +513,70 @@ def test_bad_aggregator_stratification(mocked_event): @pytest.mark.parametrize( - "pop_filter", + "pop_filter, stratifications", [ - 'familiar=="spaghetti_yeti"', - 'familiar=="cat"', - "", + ('familiar=="cat"', tuple()), + ('familiar=="spaghetti_yeti"', tuple()), + ("", ("new_col1",)), + ("", ("new_col1", "new_col2")), + ('familiar=="cat"', ("new_col1",)), + ("", tuple()), + ], + ids=[ + "pop_filter", + "pop_filter_empties_dataframe", + "single_excluded_stratification", + "two_excluded_stratifications", + "pop_filter_and_excluded_stratification", + "no_pop_filter_or_excluded_stratifications", ], ) -def test__filter_population(pop_filter): +def test__filter_population(pop_filter, stratifications): + population = BASE_POPULATION.copy() + if stratifications: + # Make some of the stratifications missing to mimic mapping to excluded categories + population["new_col1"] = "new_value1" + population.loc[population["tracked"] == True, "new_col1"] = np.nan + if len(stratifications) == 2: + population["new_col2"] = "new_value2" + population.loc[population["new_col1"].notna(), "new_col2"] = np.nan + # Add on the post-stratified columns + for stratification in stratifications: + mapped_col = f"{stratification}_mapped_values" + population[mapped_col] = population[stratification] + filtered_pop = ResultsContext()._filter_population( - population=BASE_POPULATION, pop_filter=pop_filter + population=population, pop_filter=pop_filter, stratification_names=stratifications ) + expected = population.copy() if pop_filter: familiar = pop_filter.split("==")[1].strip('"') - assert filtered_pop.equals(BASE_POPULATION[BASE_POPULATION["familiar"] == familiar]) - if not familiar in filtered_pop["familiar"].values: - assert filtered_pop.empty - else: - # An empty pop filter should return the entire population - assert filtered_pop.equals(BASE_POPULATION) + expected = expected[expected["familiar"] == familiar] + for stratification in stratifications: + expected = expected[expected[stratification].notna()] + assert filtered_pop.equals(expected) @pytest.mark.parametrize( "stratifications, values", [ (("familiar",), [FAMILIARS]), - (("familiar", "house"), [FAMILIARS, CATEGORIES]), + (("familiar", "house"), [FAMILIARS, HOUSE_CATEGORIES]), ((), "foo"), ], ) def test__get_groups(stratifications, values): + + filtered_pop = BASE_POPULATION.copy() + # Generate the post-stratified columns + for stratification in stratifications: + mapped_col = f"{stratification}_mapped_values" + filtered_pop[mapped_col] = filtered_pop[stratification] groups = ResultsContext()._get_groups( - stratifications=stratifications, filtered_pop=BASE_POPULATION + stratifications=stratifications, filtered_pop=filtered_pop ) assert isinstance(groups, DataFrameGroupBy) - if len(stratifications) > 0: + if stratifications: combinations = set(itertools.product(*values)) if len(values) == 1: # convert from set of tuples to set of strings diff --git a/tests/framework/results/test_interface.py b/tests/framework/results/test_interface.py index 5876c2de4..fe5036ead 100644 --- a/tests/framework/results/test_interface.py +++ b/tests/framework/results/test_interface.py @@ -4,10 +4,12 @@ import pandas as pd import pytest +from layered_config_tree import LayeredConfigTree +from loguru import logger -from tests.framework.results.helpers import BASE_POPULATION -from tests.framework.results.helpers import CATEGORIES as HOUSES -from tests.framework.results.helpers import FAMILIARS, mock_get_value +from tests.framework.results.helpers import BASE_POPULATION, FAMILIARS +from tests.framework.results.helpers import HOUSE_CATEGORIES as HOUSES +from tests.framework.results.helpers import mock_get_value from vivarium.framework.results import ResultsInterface, ResultsManager @@ -15,6 +17,107 @@ def _silly_aggregator(_: pd.DataFrame) -> float: return 1.0 +#################################### +# Test stratification registration # +#################################### + + +def test_register_stratification(mocker): + def _silly_mapper(): + # NOTE: it does not actually matter what this mapper returns for this test + return {"some-category", "some-other-category", "some-unwanted-category"} + + builder = mocker.Mock() + # Set up mock builder with mocked get_value call for Pipelines + mocker.patch.object(builder, "value.get_value") + builder.value.get_value = MethodType(mock_get_value, builder) + mgr = ResultsManager() + mgr.setup(builder) + interface = ResultsInterface(mgr) + + # Check pre-registration stratifications and manager required columns/values + assert len(mgr._results_context.stratifications) == 0 + assert mgr._required_columns == {"tracked"} + assert len(mgr._required_values) == 0 + + interface.register_stratification( + name="some-name", + categories=["some-category", "some-other-category", "some-unwanted-category"], + excluded_categories=["some-unwanted-category"], + mapper=_silly_mapper, + is_vectorized=False, + requires_columns=["some-column", "some-other-column"], + requires_values=["some-value", "some-other-value"], + ) + + # Check that manager required columns/values have been updated + assert mgr._required_columns == {"tracked", "some-column", "some-other-column"} + assert mgr._required_values == {"some-value", "some-other-value"} + + # Check stratification registration + stratifications = mgr._results_context.stratifications + assert len(stratifications) == 1 + stratification = stratifications[0] + assert stratification.name == "some-name" + assert stratification.sources == [ + "some-column", + "some-other-column", + "some-value", + "some-other-value", + ] + assert stratification.categories == ["some-category", "some-other-category"] + assert stratification.excluded_categories == ["some-unwanted-category"] + assert stratification.mapper == _silly_mapper + assert stratification.is_vectorized is False + + +def test_register_binned_stratification(mocker): + + mgr = ResultsManager() + mgr.logger = logger + builder = mocker.Mock() + mgr._results_context.setup(builder) + + # Check pre-registration stratifications and manager required columns/values + assert len(mgr._results_context.stratifications) == 0 + assert mgr._required_columns == {"tracked"} + assert len(mgr._required_values) == 0 + + mgr.register_binned_stratification( + target="some-column-to-bin", + binned_column="new-binned-column", + bin_edges=[1, 2, 3], + labels=["1_to_2", "2_to_3"], + excluded_categories=["2_to_3"], + target_type="column", + some_kwarg="some-kwarg", + some_other_kwarg="some-other-kwarg", + ) + + # Check that manager required columns/values have been updated + assert mgr._required_columns == {"tracked", "some-column-to-bin"} + assert len(mgr._required_values) == 0 + + # Check stratification registration + stratifications = mgr._results_context.stratifications + assert len(stratifications) == 1 + stratification = stratifications[0] + assert stratification.name == "new-binned-column" + assert stratification.sources == ["some-column-to-bin"] + assert stratification.categories == ["1_to_2"] + assert stratification.excluded_categories == ["2_to_3"] + # Cannot access the mapper because it's in local scope, so check __repr__ + assert "function ResultsManager.register_binned_stratification.._bin_data" in str( + stratification.mapper + ) + assert stratification.is_vectorized is True + + +################################# +# Test observation registration # +################################# + + @pytest.mark.parametrize( ("obs_type", "missing_args"), [ @@ -271,13 +374,20 @@ def test_register_adding_observation_when_options(when, mocker): mgr = ResultsManager() results_interface = ResultsInterface(mgr) builder = mocker.Mock() - builder.configuration.stratification.default = [] + builder.configuration.stratification = LayeredConfigTree( + {"default": [], "excluded_categories": {}} + ) mgr.setup(builder) # register stratifications - results_interface.register_stratification("house", HOUSES, None, True, ["house"], []) results_interface.register_stratification( - "familiar", FAMILIARS, None, True, ["familiar"], [] + name="house", categories=HOUSES, is_vectorized=True, requires_columns=["house"] + ) + results_interface.register_stratification( + name="familiar", + categories=FAMILIARS, + is_vectorized=True, + requires_columns=["familiar"], ) time_step__prepare_mock_aggregator = mocker.Mock(side_effect=lambda x: 1.0) diff --git a/tests/framework/results/test_manager.py b/tests/framework/results/test_manager.py index 1e3e0191e..f80fa8d80 100644 --- a/tests/framework/results/test_manager.py +++ b/tests/framework/results/test_manager.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +from layered_config_tree import LayeredConfigTree from loguru import logger from pandas.api.types import CategoricalDtype @@ -12,12 +13,12 @@ BIN_LABELS, BIN_SILLY_BIN_EDGES, BIN_SOURCE, - CATEGORIES, FAMILIARS, HARRY_POTTER_CONFIG, + HOUSE_CATEGORIES, NAME, + NAME_COLUMNS, POWER_LEVEL_GROUP_LABELS, - SOURCES, STUDENT_HOUSES, CatBombObserver, ExamScoreObserver, @@ -31,7 +32,7 @@ ValedictorianObserver, mock_get_value, sorting_hat_serial, - sorting_hat_vector, + sorting_hat_vectorized, verify_stratification_added, ) from vivarium.framework.results import VALUE_COLUMN @@ -89,32 +90,52 @@ def test__get_stratifications( @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized", + "name, sources, categories, excluded_categories, mapper, is_vectorized", [ ( NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, + NAME_COLUMNS, + HOUSE_CATEGORIES, + [], + sorting_hat_vectorized, True, ), ( NAME, - SOURCES, - CATEGORIES, + NAME_COLUMNS, + HOUSE_CATEGORIES, + [], sorting_hat_serial, False, ), + ( + NAME, + NAME_COLUMNS, + HOUSE_CATEGORIES, + ["gryffindor"], + sorting_hat_vectorized, + True, + ), ], - ids=["vectorized_mapper", "non-vectorized_mapper"], + ids=["vectorized_mapper", "non-vectorized_mapper", "excluded_categories"], ) def test_register_stratification_no_pipelines( - name, sources, categories, mapper, is_vectorized, mocker + name, sources, categories, excluded_categories, mapper, is_vectorized, mocker ): mgr = ResultsManager() builder = mocker.Mock() + builder.configuration.stratification = LayeredConfigTree( + {"default": [], "excluded_categories": {}} + ) mgr.setup(builder) - mgr.register_stratification(name, categories, mapper, is_vectorized, sources, []) + mgr.register_stratification( + name=name, + categories=categories, + excluded_categories=excluded_categories, + mapper=mapper, + is_vectorized=is_vectorized, + requires_columns=sources, + ) for item in sources: assert item in mgr._required_columns assert verify_stratification_added( @@ -122,6 +143,7 @@ def test_register_stratification_no_pipelines( name, sources, categories, + excluded_categories, mapper, is_vectorized, ) @@ -132,15 +154,15 @@ def test_register_stratification_no_pipelines( [ ( NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, + NAME_COLUMNS, + HOUSE_CATEGORIES, + sorting_hat_vectorized, True, ), ( NAME, - SOURCES, - CATEGORIES, + NAME_COLUMNS, + HOUSE_CATEGORIES, sorting_hat_serial, False, ), @@ -152,11 +174,22 @@ def test_register_stratification_with_pipelines( ): mgr = ResultsManager() builder = mocker.Mock() + builder.configuration.stratification = LayeredConfigTree( + {"default": [], "excluded_categories": {}} + ) # Set up mock builder with mocked get_value call for Pipelines mocker.patch.object(builder, "value.get_value") builder.value.get_value = MethodType(mock_get_value, builder) mgr.setup(builder) - mgr.register_stratification(name, categories, mapper, is_vectorized, [], sources) + mgr.register_stratification( + name=name, + categories=categories, + excluded_categories=None, + mapper=mapper, + is_vectorized=is_vectorized, + requires_columns=[], + requires_values=sources, + ) for item in sources: assert item in mgr._required_values assert verify_stratification_added( @@ -164,6 +197,7 @@ def test_register_stratification_with_pipelines( name, sources, categories, + [], mapper, is_vectorized, ) @@ -174,15 +208,15 @@ def test_register_stratification_with_pipelines( [ ( # expected Stratification for vectorized NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, + NAME_COLUMNS, + HOUSE_CATEGORIES, + sorting_hat_vectorized, True, ), ( # expected Stratification for non-vectorized NAME, - SOURCES, - CATEGORIES, + NAME_COLUMNS, + HOUSE_CATEGORIES, sorting_hat_serial, False, ), @@ -194,13 +228,22 @@ def test_register_stratification_with_column_and_pipelines( ): mgr = ResultsManager() builder = mocker.Mock() + builder.configuration.stratification = LayeredConfigTree( + {"default": [], "excluded_categories": {}} + ) # Set up mock builder with mocked get_value call for Pipelines mocker.patch.object(builder, "value.get_value") builder.value.get_value = MethodType(mock_get_value, builder) mgr.setup(builder) mocked_column_name = "silly_column" mgr.register_stratification( - name, categories, mapper, is_vectorized, [mocked_column_name], sources + name=name, + categories=categories, + excluded_categories=None, + mapper=mapper, + is_vectorized=is_vectorized, + requires_columns=[mocked_column_name], + requires_values=sources, ) assert mocked_column_name in mgr._required_columns for item in sources: @@ -212,6 +255,7 @@ def test_register_stratification_with_column_and_pipelines( name, all_sources, categories, + [], mapper, is_vectorized, ) @@ -222,28 +266,6 @@ def test_register_stratification_with_column_and_pipelines( ############################################## -def test_register_binned_stratification(): - mgr = ResultsManager() - mgr.logger = logger - assert len(mgr._results_context.stratifications) == 0 - mgr.register_binned_stratification( - target=BIN_SOURCE, - target_type="column", - binned_column=BIN_BINNED_COLUMN, - bin_edges=BIN_SILLY_BIN_EDGES, - labels=BIN_LABELS, - ) - assert len(mgr._results_context.stratifications) == 1 - strat = mgr._results_context.stratifications[0] - assert strat.name == BIN_BINNED_COLUMN - assert strat.sources == [BIN_SOURCE] - assert strat.categories == BIN_LABELS - # Cannot access the mapper because it's in local scope, so check __repr__ - assert "function ResultsManager.register_binned_stratification.._bin_data" in str( - strat.mapper - ) - - @pytest.mark.parametrize( "bins, labels", [(BIN_SILLY_BIN_EDGES, BIN_LABELS[1:]), (BIN_SILLY_BIN_EDGES[1:], BIN_LABELS)], @@ -257,10 +279,11 @@ def test_register_binned_stratification_raises_bins_labels_mismatch(bins, labels ): mgr.register_binned_stratification( target=BIN_SOURCE, - target_type="column", binned_column=BIN_BINNED_COLUMN, bin_edges=bins, labels=labels, + excluded_categories=None, + target_type="column", ) @@ -269,10 +292,11 @@ def test_binned_stratification_mapper(): mgr.logger = logger mgr.register_binned_stratification( target=BIN_SOURCE, - target_type="column", binned_column=BIN_BINNED_COLUMN, bin_edges=BIN_SILLY_BIN_EDGES, labels=BIN_LABELS, + excluded_categories=None, + target_type="column", ) strat = mgr._results_context.stratifications[0] data = pd.Series([-np.inf] + BIN_SILLY_BIN_EDGES + [np.inf]) diff --git a/tests/framework/results/test_observation.py b/tests/framework/results/test_observation.py index 87e0bea2a..f5dcc578e 100644 --- a/tests/framework/results/test_observation.py +++ b/tests/framework/results/test_observation.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from tests.framework.results.helpers import BASE_POPULATION, CATEGORIES, FAMILIARS +from tests.framework.results.helpers import BASE_POPULATION, FAMILIARS, HOUSE_CATEGORIES from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext from vivarium.framework.results.observation import ( @@ -63,8 +63,13 @@ def test_stratified_observation__aggregate( - If no aggregator_resources are provided, then we want a full aggregation of the groups. - _aggregate can return either a pd.Series or a pd.DataFrame of any number of columns """ + + filtered_pop = BASE_POPULATION.copy() + for stratification in stratifications: + mapped_col = f"{stratification}_mapped_values" + filtered_pop[mapped_col] = filtered_pop[stratification] groups = ResultsContext()._get_groups( - stratifications=stratifications, filtered_pop=BASE_POPULATION + stratifications=stratifications, filtered_pop=filtered_pop ) aggregates = stratified_observation._aggregate( pop_groups=groups, @@ -74,7 +79,7 @@ def test_stratified_observation__aggregate( if aggregator == len: if stratifications: stratification_idx = ( - set(itertools.product(*(FAMILIARS, CATEGORIES))) + set(itertools.product(*(FAMILIARS, HOUSE_CATEGORIES))) if "house" in stratifications else set(FAMILIARS) ) @@ -88,7 +93,7 @@ def test_stratified_observation__aggregate( expected = BASE_POPULATION[["power_level", "tracked"]].sum() / groups.ngroups if stratifications: stratification_idx = ( - set(itertools.product(*(FAMILIARS, CATEGORIES))) + set(itertools.product(*(FAMILIARS, HOUSE_CATEGORIES))) if "house" in stratifications else set(FAMILIARS) ) @@ -166,10 +171,16 @@ def test_stratified_observation__expand_index(aggregates, stratified_observation ) def test_stratified_observation_results_gatherer(stratifications, stratified_observation): ctx = ResultsContext() + # Append the post-stratified columns + filtered_population = BASE_POPULATION.copy() + for stratification in stratifications: + mapped_col = f"{stratification}_mapped_values" + filtered_population[mapped_col] = filtered_population[stratification] pop_groups = ctx._get_groups( - stratifications=stratifications, filtered_pop=BASE_POPULATION + stratifications=stratifications, filtered_pop=filtered_population ) df = stratified_observation.results_gatherer(pop_groups, stratifications) + ctx._rename_stratification_columns(df) assert set(df.columns) == set(["value"]) expected_idx_names = ( list(stratifications) if len(stratifications) > 0 else ["stratification"] diff --git a/tests/framework/results/test_stratification.py b/tests/framework/results/test_stratification.py index c0a881689..00e47f58d 100644 --- a/tests/framework/results/test_stratification.py +++ b/tests/framework/results/test_stratification.py @@ -1,143 +1,149 @@ +import re + +import numpy as np +import pandas as pd import pytest from tests.framework.results.helpers import ( - CATEGORIES, + HOUSE_CATEGORIES, NAME, - SOURCES, + NAME_COLUMNS, STUDENT_HOUSES, STUDENT_TABLE, sorting_hat_bad_mapping, sorting_hat_serial, - sorting_hat_vector, + sorting_hat_vectorized, ) from vivarium.framework.results.manager import ResultsManager -from vivarium.framework.results.stratification import Stratification +from vivarium.framework.results.stratification import ( + Stratification, + get_mapped_col_name, + get_original_col_name, +) -######### -# Tests # -######### @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized, expected_output", + "mapper, is_vectorized", [ ( # expected output for vectorized - NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, + sorting_hat_vectorized, True, - STUDENT_HOUSES, ), ( # expected output for non-vectorized - NAME, - SOURCES, - CATEGORIES, sorting_hat_serial, False, - STUDENT_HOUSES, ), ], ids=["vectorized_mapper", "non-vectorized_mapper"], ) -def test_stratification(name, sources, categories, mapper, is_vectorized, expected_output): - my_stratification = Stratification(name, sources, categories, mapper, is_vectorized) - output = my_stratification(STUDENT_TABLE)[name] - assert output.eq(expected_output).all() +def test_stratification(mapper, is_vectorized): + my_stratification = Stratification( + name=NAME, + sources=NAME_COLUMNS, + categories=HOUSE_CATEGORIES, + excluded_categories=[], + mapper=mapper, + is_vectorized=is_vectorized, + ) + output = my_stratification(STUDENT_TABLE) + assert output.eq(STUDENT_HOUSES).all() @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized, expected_exception", + "sources, categories, mapper, msg_match", [ - ( # empty sources list with no defined mapper (default mapper) - NAME, + ( [], - CATEGORIES, + HOUSE_CATEGORIES, None, - True, - ValueError, + f"No mapper provided for stratification {NAME} with 0 stratification sources.", ), - ( # sources list with more than one column with no defined mapper (default mapper) - NAME, - SOURCES, - CATEGORIES, + ( + NAME_COLUMNS, + HOUSE_CATEGORIES, None, - True, - ValueError, + f"No mapper provided for stratification {NAME} with {len(NAME_COLUMNS)} stratification sources.", ), - ( # empty sources list with no defined mapper (default mapper) - NAME, + ( [], - CATEGORIES, - None, - True, - ValueError, + HOUSE_CATEGORIES, + sorting_hat_vectorized, + "The sources argument must be non-empty.", ), - ( # empty categories list - NAME, - SOURCES, + ( + NAME_COLUMNS, [], - None, - True, - ValueError, + FileNotFoundError, + "The categories argument must be non-empty.", ), ], + ids=[ + "no_mapper_empty_sources", + "no_mapper_multiple_sources", + "with_mapper_empty_sources", + "empty_categories", + ], ) -def test_stratification_init_raises( - name, sources, categories, mapper, is_vectorized, expected_exception -): - with pytest.raises(expected_exception): - assert Stratification(name, sources, categories, mapper, is_vectorized) +def test_stratification_init_raises(sources, categories, mapper, msg_match): + with pytest.raises(ValueError, match=re.escape(msg_match)): + Stratification(NAME, sources, categories, [], mapper, True) @pytest.mark.parametrize( - "name, sources, categories, mapper, is_vectorized, expected_exception", + "sources, mapper, is_vectorized, expected_exception, error_match", [ ( - NAME, - SOURCES, - CATEGORIES, + NAME_COLUMNS, sorting_hat_bad_mapping, False, ValueError, + "Invalid values mapped to hogwarts_house: ['pancakes']", ), ( - NAME, ["middle_initial"], - CATEGORIES, - sorting_hat_vector, + sorting_hat_vectorized, True, KeyError, + "None of [Index(['middle_initial'], dtype='object')] are in the [columns]", ), ( - NAME, - SOURCES, - CATEGORIES, + NAME_COLUMNS, sorting_hat_serial, True, - Exception, + Exception, # Can be any exception + "", # Can be any error message ), ( - NAME, - SOURCES, - CATEGORIES, - sorting_hat_vector, + NAME_COLUMNS, + sorting_hat_vectorized, False, - Exception, + Exception, # Can be any exception + "", # Can be any error message + ), + ( + NAME_COLUMNS, + lambda df: pd.Series(np.nan, index=df.index), + True, + ValueError, + f"Invalid values mapped to hogwarts_house: [{np.nan}]", ), ], ids=[ - "category_not_in_categories", + "unknown_category", "source_not_in_population_columns", "vectorized_with_serial_mapper", - "not_vectorized_with_serial_mapper", + "not_vectorized_with_vectorized_mapper", + "mapper_returns_null", ], ) def test_stratification_call_raises( - name, sources, categories, mapper, is_vectorized, expected_exception + sources, mapper, is_vectorized, expected_exception, error_match ): - my_stratification = Stratification(name, sources, categories, mapper, is_vectorized) - with pytest.raises(expected_exception): - raise my_stratification(STUDENT_TABLE) + my_stratification = Stratification( + NAME, sources, HOUSE_CATEGORIES, [], mapper, is_vectorized + ) + with pytest.raises(expected_exception, match=re.escape(error_match)): + my_stratification(STUDENT_TABLE) @pytest.mark.parametrize("default_stratifications", [["age", "sex"], ["age"], []]) @@ -150,3 +156,22 @@ def test_setting_default_stratifications(default_stratifications, mocker): mgr.setup(builder) assert mgr._results_context.default_stratifications == default_stratifications + + +def test_get_mapped_column_name(): + assert get_mapped_col_name("foo") == "foo_mapped_values" + + +@pytest.mark.parametrize( + "col_name, expected", + [ + ("foo_mapped_values", "foo"), + ("foo", "foo"), + ("foo_mapped_values_mapped_values", "foo_mapped_values"), + ("foo_mapped_values2", "foo_mapped_values2"), + ("_mapped_values_foo", "_mapped_values_foo"), + ("_mapped_values_foo_mapped_values", "_mapped_values_foo"), + ], +) +def test_get_original_col_name(col_name, expected): + assert get_original_col_name(col_name) == expected