From 889a8050f9cdc677e97e4c1a1437c5d9835bd6c9 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:32:37 -0600 Subject: [PATCH] update results management system docstrings (#464) * update docstrings * pin sphinx-rtd-theme>=0.6 * various type hint updates * change Stratification __call__() method stratify() * move shared custom types to a new types.py module --- CHANGELOG.rst | 4 + setup.py | 2 +- src/vivarium/framework/artifact/hdf.py | 6 +- src/vivarium/framework/engine.py | 3 +- src/vivarium/framework/event.py | 2 +- src/vivarium/framework/lookup/__init__.py | 3 +- .../framework/lookup/interpolation.py | 6 +- src/vivarium/framework/lookup/manager.py | 2 +- src/vivarium/framework/lookup/table.py | 6 +- src/vivarium/framework/results/context.py | 177 +++++++++---- src/vivarium/framework/results/interface.py | 244 ++++++++--------- src/vivarium/framework/results/manager.py | 158 +++++++---- src/vivarium/framework/results/observation.py | 248 +++++++++++++----- src/vivarium/framework/results/observer.py | 20 +- .../framework/results/stratification.py | 111 ++++++-- src/vivarium/framework/state_machine.py | 2 +- src/vivarium/framework/time.py | 10 +- src/vivarium/framework/values.py | 8 +- src/vivarium/interface/interactive.py | 2 +- src/vivarium/types.py | 12 + tests/framework/components/test_manager.py | 2 +- tests/framework/results/test_interface.py | 24 +- tests/framework/results/test_observer.py | 4 +- .../framework/results/test_stratification.py | 14 +- tests/framework/test_state_machine.py | 2 +- 25 files changed, 698 insertions(+), 374 deletions(-) create mode 100644 src/vivarium/types.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d550adba7..ea5a4def6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**TBD - TBD** + + - Update results-related docstrings + **3.0.1- 08/20/24** - Create script to find matching dependency branches diff --git a/setup.py b/setup.py index 
8fe2b7566..896eeadd4 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ doc_requirements = [ "sphinx>=4.0", - "sphinx-rtd-theme", + "sphinx-rtd-theme>=0.6", "sphinx-click", "IPython", "matplotlib", diff --git a/src/vivarium/framework/artifact/hdf.py b/src/vivarium/framework/artifact/hdf.py index cd36863ed..35bd098ea 100644 --- a/src/vivarium/framework/artifact/hdf.py +++ b/src/vivarium/framework/artifact/hdf.py @@ -48,7 +48,7 @@ import tables from tables.nodes import filenode -PandasObj = (pd.DataFrame, pd.Series) +_PandasObj = (pd.DataFrame, pd.Series) #################### # Public interface # @@ -106,7 +106,7 @@ def write(path: Union[str, Path], entity_key: str, data: Any): path = _get_valid_hdf_path(path) entity_key = EntityKey(entity_key) - if isinstance(data, PandasObj): + if isinstance(data, _PandasObj): _write_pandas_data(path, entity_key, data) else: _write_json_blob(path, entity_key, data) @@ -330,7 +330,7 @@ def _get_valid_hdf_path(path: Union[str, Path]) -> Path: return path -def _write_pandas_data(path: Path, entity_key: EntityKey, data: Union[PandasObj]): +def _write_pandas_data(path: Path, entity_key: EntityKey, data: Union[_PandasObj]): """Write data in a pandas format to an HDF file. 
This method currently supports :class:`pandas DataFrame` objects, with or diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index e3c47ba9a..cf998caec 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -44,8 +44,9 @@ from vivarium.framework.randomness import RandomnessInterface from vivarium.framework.resource import ResourceInterface from vivarium.framework.results import ResultsInterface -from vivarium.framework.time import Time, TimeInterface +from vivarium.framework.time import TimeInterface from vivarium.framework.values import ValuesInterface +from vivarium.types import Time class SimulationContext: diff --git a/src/vivarium/framework/event.py b/src/vivarium/framework/event.py index 008fa6110..550dbe8b1 100644 --- a/src/vivarium/framework/event.py +++ b/src/vivarium/framework/event.py @@ -32,8 +32,8 @@ import pandas as pd from vivarium.framework.lifecycle import ConstraintError -from vivarium.framework.time import Time, Timedelta from vivarium.manager import Manager +from vivarium.types import Time, Timedelta class Event(NamedTuple): diff --git a/src/vivarium/framework/lookup/__init__.py b/src/vivarium/framework/lookup/__init__.py index 7ed23eb24..3d41cb28e 100644 --- a/src/vivarium/framework/lookup/__init__.py +++ b/src/vivarium/framework/lookup/__init__.py @@ -3,4 +3,5 @@ LookupTableManager, validate_build_table_parameters, ) -from vivarium.framework.lookup.table import LookupTable, LookupTableData, ScalarValue +from vivarium.framework.lookup.table import LookupTable +from vivarium.types import LookupTableData, ScalarValue diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index 02367bb9a..ce3c6eb66 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -13,7 +13,7 @@ import numpy as np import pandas as pd -ParameterType = Union[List[List[str]], List[Tuple[str, str, str]]] 
+_ParameterType = Union[List[List[str]], List[Tuple[str, str, str]]] class Interpolation: @@ -39,7 +39,7 @@ def __init__( self, data: pd.DataFrame, categorical_parameters: Union[List[str], Tuple[str, ...]], - continuous_parameters: ParameterType, + continuous_parameters: _ParameterType, value_columns: Union[List[str], Tuple[str, ...]], order: int, extrapolate: bool, @@ -265,7 +265,7 @@ class Order0Interp: def __init__( self, data, - continuous_parameters: ParameterType, + continuous_parameters: _ParameterType, value_columns: List[str], extrapolate: bool, validate: bool, diff --git a/src/vivarium/framework/lookup/manager.py b/src/vivarium/framework/lookup/manager.py index cc68eae32..12fc56687 100644 --- a/src/vivarium/framework/lookup/manager.py +++ b/src/vivarium/framework/lookup/manager.py @@ -22,10 +22,10 @@ CategoricalTable, InterpolatedTable, LookupTable, - LookupTableData, ScalarTable, ) from vivarium.manager import Manager +from vivarium.types import LookupTableData if TYPE_CHECKING: from vivarium.framework.engine import Builder diff --git a/src/vivarium/framework/lookup/table.py b/src/vivarium/framework/lookup/table.py index 9cc6f6d7d..b5149276e 100644 --- a/src/vivarium/framework/lookup/table.py +++ b/src/vivarium/framework/lookup/table.py @@ -13,17 +13,13 @@ import dataclasses from abc import ABC, abstractmethod -from datetime import datetime, timedelta -from numbers import Number from typing import Callable, List, Tuple, Union import numpy as np import pandas as pd from vivarium.framework.lookup.interpolation import Interpolation - -ScalarValue = Union[Number, timedelta, datetime] -LookupTableData = Union[ScalarValue, pd.DataFrame, List[ScalarValue], Tuple[ScalarValue]] +from vivarium.types import ScalarValue @dataclasses.dataclass diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 99562fabf..17776b84f 100644 --- a/src/vivarium/framework/results/context.py +++ 
b/src/vivarium/framework/results/context.py @@ -1,7 +1,13 @@ +""" +=============== +Results Context +=============== +""" + from __future__ import annotations from collections import defaultdict -from typing import Callable, Generator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Type, Union import pandas as pd from pandas.core.groupby import DataFrameGroupBy @@ -15,30 +21,40 @@ get_mapped_col_name, get_original_col_name, ) +from vivarium.types import ScalarValue class ResultsContext: - """ - Manager context for organizing observations and the stratifications they require. + """Manager for organizing observations and their required stratifications. - This context object is wholly contained by the manager :class:`vivarium.framework.results.manager.ResultsManger`. - Stratifications can be added to the context through the manager via the - :meth:`vivarium.framework.results.context.ResultsContext.add_observation` method. + This context object is wholly contained by :class:`ResultsManager `. + Stratifications and observations can be added to the context through the manager via the + :meth:`vivarium.framework.results.context.ResultsContext.add_stratification` and + :meth:`vivarium.framework.results.context.ResultsContext.register_observation` methods, respectively. + + Attributes + ---------- + default_stratifications + List of column names to use for stratifying results. + stratifications + List of :class:`Stratification ` + objects to be applied to results. + excluded_categories + Dictionary of possible per-metric stratification values to be excluded + from results processing. + observations + Dictionary of observation details. It is of the format + {lifecycle_phase: {(pop_filter, stratifications): list[Observation]}}. + Allowable lifecycle_phases are "time_step__prepare", "time_step", + "time_step__cleanup", and "collect_metrics". + logger + Logger for the results context. 
""" def __init__(self) -> None: self.default_stratifications: List[str] = [] self.stratifications: List[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} - # keys are event names: [ - # "time_step__prepare", - # "time_step", - # "time_step__cleanup", - # "collect_metrics", - # ] - # values are dicts with - # key (filter, grouper) - # value Observation self.observations: defaultdict = defaultdict(lambda: defaultdict(list)) @property @@ -46,6 +62,11 @@ def name(self) -> str: return "results_context" def setup(self, builder: Builder) -> None: + """Set up the results context. + + This method is called by the :class:`ResultsManager ` + during the setup phase of that object. + """ self.logger = builder.logging.get_logger(self.name) self.excluded_categories = ( builder.configuration.stratification.excluded_categories.to_dict() @@ -53,6 +74,18 @@ def setup(self, builder: Builder) -> None: # noinspection PyAttributeOutsideInit def set_default_stratifications(self, default_grouping_columns: List[str]) -> None: + """Set the default stratifications to be used by stratified observations. + + Parameters + ---------- + default_grouping_columns + List of stratifications to be used. + + Raises + ------ + ResultsConfigurationError + If the `self.default_stratifications` attribute has already been set. + """ if self.default_stratifications: raise ResultsConfigurationError( "Multiple calls are being made to set default grouping columns " @@ -66,35 +99,42 @@ def add_stratification( sources: List[str], categories: List[str], excluded_categories: Optional[List[str]], - mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]], + mapper: Optional[ + Union[ + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[ScalarValue], str], + ] + ], is_vectorized: bool, ) -> None: - """Add a stratification to the context. + """Add a stratification to the results context. 
Parameters ---------- name - Name of the of the column created by the stratification. + Name of the stratification. sources - A list of the columns and values needed for the mapper to determinate - categorization. + A list of the columns and values needed as input for the `mapper`. categories - List of string values that the mapper is allowed to output. + Exhaustive list of all possible stratification values. excluded_categories - List of mapped string values to be excluded from results processing. + List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. mapper - A callable that emits values in `categories` given inputs from columns - and values in the `requires_columns` and `requires_values`, respectively. + A callable that maps the columns and value pipelines specified by + `sources` to the stratification categories. It can either map the entire + population or an individual simulant. A simulation will fail if the `mapper` + ever produces an invalid value. is_vectorized - `True` if the mapper function expects a `DataFrame`, and `False` if it - expects a row of the `DataFrame` and should be used by calling :func:`df.apply`. - + True if the `mapper` function will map the entire population, and False + if it will only map a single simulant. - Returns + Raises ------ - None - + ValueError + - If the stratification `name` is already used. + - If there are duplicate `categories`. + - If any `excluded_categories` are not in `categories`. """ already_used = [ stratification @@ -151,29 +191,28 @@ def register_observation( when: str, **kwargs, ) -> None: - """Add an observation to the context. + """Add an observation to the results context. Parameters ---------- observation_type - Class type of the observation to register. + Specific class type of observation to register. name - Name of the metric to observe and result file. + Name of the observation. 
It will also be the name of the output results file + for this particular observation. pop_filter - A Pandas query filter string to filter the population down to the - simulants who should be considered for the observation. + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. when - String name of the phase of a time-step the observation should happen. - Valid values are: `"time_step__prepare"`, `"time_step"`, - `"time_step__cleanup"`, `"collect_metrics"`. - kwargs - Additional keyword arguments to pass to the observation constructor. + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + **kwargs + Additional keyword arguments to be passed to the observation's constructor. - - Returns + Raises ------ - None - + ValueError + If the observation `name` is already used. """ already_used = None if self.observations: @@ -190,6 +229,9 @@ def register_observation( raise ValueError( f"Observation name '{name}' is already used: {str(already_used)}." ) + + # Instantiate the observation and add it and its (pop_filter, stratifications) + # tuple as a key-value pair to the self.observations[when] dictionary. observation = observation_type(name=name, pop_filter=pop_filter, when=when, **kwargs) self.observations[observation.when][ (observation.pop_filter, observation.stratifications) @@ -206,7 +248,31 @@ def gather_results( None, None, ]: - """Generate current results for all observations at this lifecycle phase and event.""" + """Generate and yield current results for all observations at this lifecycle + phase and event. Each set of results are stratified and grouped by + all registered stratifications as well as filtered by their respective + observation's pop_filter. + + Parameters + ---------- + population + The current population DataFrame. + lifecycle_phase + The current lifecycle phase. 
+ event + The current Event. + + Yields + ------ + A tuple containing each observation's newly observed results, the name of + the observation, and the observations results updater function. Note that + it yields (None, None, None) if the filtered population is empty. + + Raises + ------ + ValueError + If a stratification's temporary column name already exists in the population DataFrame. + """ for stratification in self.stratifications: # Add new columns of mapped values to the population to prevent name collisions @@ -217,7 +283,7 @@ def gather_results( "already exists in the state table or as a pipeline which is a required " "name for stratifying results - choose a different name." ) - population[new_column] = stratification(population) + population[new_column] = stratification.stratify(population) # Optimization: We store all the producers by pop_filter and stratifications # so that we only have to apply them once each time we compute results. @@ -249,9 +315,7 @@ def _filter_population( pop_filter: str, stratification_names: Optional[tuple[str, ...]], ) -> pd.DataFrame: - """Filter the population based on the filter string as well as any - excluded stratification categories - """ + """Filter out simulants not to observe.""" pop = population.query(pop_filter) if pop_filter else population.copy() if stratification_names: # Drop all rows in the mapped_stratification columns that have NaN values @@ -268,10 +332,15 @@ def _filter_population( def _get_groups( stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame ) -> DataFrameGroupBy: - """Group the population by stratifications. - NOTE: Stratifications at this point can be an empty tuple. - HACK: It's a bit hacky how we are handling the groupby object if there - are no stratifications. The alternative is to use the entire population + """Group the population by stratification. + + Notes + ----- + Stratifications at this point can be an empty tuple. + + HACK: If there are no `stratifications` (i.e. 
it's an empty tuple), we + create a single group of the entire `filtered_pop` index and assign + it a name of "all". The alternative is to use the entire population instead of a groupby object, but then we would need to handle the different ways the aggregator can behave. """ @@ -286,7 +355,7 @@ def _get_groups( return pop_groups def _rename_stratification_columns(self, results: pd.DataFrame) -> None: - """convert stratified mapped index names to original""" + """Convert the temporary stratified mapped index names back to their original names.""" if isinstance(results.index, pd.MultiIndex): idx_names = [get_original_col_name(name) for name in results.index.names] results.rename_axis(index=idx_names, inplace=True) diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 228bcdefc..88fac452b 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -1,6 +1,16 @@ +""" +========================== +Vivarium Results Interface +========================== + +This module provides a :class:`ResultsInterface ` class with +methods to register stratifications and results producers (referred to as "observations") +to a simulation. +""" + from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import pandas as pd @@ -11,9 +21,9 @@ StratifiedObservation, UnstratifiedObservation, ) +from vivarium.types import ScalarValue if TYPE_CHECKING: - # cyclic import from vivarium.framework.results.manager import ResultsManager @@ -39,19 +49,8 @@ class ResultsInterface: modeling, but are required for the stratification of produced results. The purpose of this interface is to provide controlled access to a results - backend by means of the builder object. 
It exposes methods - to register stratifications, set default stratifications, and register - results producers. There is a special case for stratifications generated - by binning continuous data into categories. - - The expected use pattern would be for a single component to register all - stratifications required by the model using :func:`register_default_stratifications`, - :func:`register_stratification`, and :func:`register_binned_stratification` - as necessary. A “binned stratification” is a stratification special case for - the very common situation when a single continuous value needs to be binned into - categorical bins. The `is_vectorized` argument should be True if the mapper - function expects a DataFrame corresponding to the whole population, and False - if it expects a row of the DataFrame corresponding to a single simulant. + backend by means of the builder object; it exposes methods to register both + stratifications and results producers (referred to as "observations"). """ def __init__(self, manager: "ResultsManager") -> None: @@ -60,7 +59,6 @@ def __init__(self, manager: "ResultsManager") -> None: @property def name(self) -> str: - """The name of this ResultsInterface.""" return self._name ################################## @@ -74,39 +72,42 @@ def register_stratification( name: str, categories: List[str], excluded_categories: Optional[List[str]] = None, - mapper: Optional[Callable[[pd.DataFrame], pd.Series[str]]] = None, + mapper: Optional[ + Union[ + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[ScalarValue], str], + ] + ] = None, is_vectorized: bool = False, requires_columns: List[str] = [], requires_values: List[str] = [], ) -> None: - """Register quantities to observe. + """Registers a stratification that can be used by stratified observations. Parameters ---------- name - Name of the of the column created by the stratification. + Name of the stratification. 
categories - List of string values that the mapper is allowed to output. + Exhaustive list of all possible stratification values. excluded_categories - List of mapped string values to be excluded from results processing. + List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. mapper - A callable that emits values in `categories` given inputs from columns - and values in the `requires_columns` and `requires_values`, respectively. + A callable that maps the columns and value pipelines specified by the + `requires_columns` and `requires_values` arguments to the stratification + categories. It can either map the entire population or an individual + simulant. A simulation will fail if the `mapper` ever produces an invalid + value. is_vectorized - `True` if the mapper function expects a `DataFrame`, and `False` if it - expects a row of the `DataFrame` and should be used by calling :func:`df.apply`. + True if the `mapper` function will map the entire population, and False + if it will only map a single simulant. requires_columns - A list of the state table columns that already need to be present - and populated in the state table before the pipeline modifier - is called. + A list of the state table columns that are required by the `mapper` + to produce the stratification. requires_values - A list of the value pipelines that need to be properly sourced - before the pipeline modifier is called. - - Returns - ------ - None + A list of the value pipelines that are required by the `mapper` to + produce the stratification. """ self._manager.register_stratification( name, @@ -128,14 +129,14 @@ def register_binned_stratification( target_type: str = "column", **cut_kwargs: Dict, ) -> None: - """Register a continuous `target` quantity to observe into bins in a `binned_column`. + """Registers a binned stratification that can be used by stratified observations. 
Parameters ---------- target - String name of the state table column or value pipeline used to stratify. + Name of the state table column or value pipeline to be binned. binned_column - String name of the column for the binned quantities. + Name of the (binned) stratification. bin_edges List of scalars defining the bin edges, passed to :meth: pandas.cut. The length must be equal to the length of `labels` plus 1. @@ -143,16 +144,13 @@ def register_binned_stratification( List of string labels for bins. The length must be equal to the length of `bin_edges` minus 1. excluded_categories - List of mapped string values to be excluded from results processing. + List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. target_type - "column" or "value" + Type specification of the `target` to be binned. "column" if it's a + state table column or "value" if it's a value pipeline. **cut_kwargs Keyword arguments for :meth: pandas.cut. - - Returns - ------ - None """ self._manager.register_binned_stratification( target, @@ -187,43 +185,44 @@ def register_stratified_observation( aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Provide the results system all the information it needs to perform a - stratified observation. + """Registers a stratified observation to the results system. Parameters ---------- name - String name for the observation. + Name of the observation. It will also be the name of the output results file + for this particular observation. pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. when - String name of the phase of a time-step the observation should happen. Valid values are: - `"time_step__prepare"`, `"time_step"`, `"time_step__cleanup"`, `"collect_metrics"`. 
+ String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". requires_columns - A list of the state table columns that are required by either the pop_filter or the aggregator. + List of the state table columns that are required by either the `pop_filter` or the `aggregator`. requires_values - A list of the value pipelines that are required by either the pop_filter or the aggregator. + List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. results_updater - A function that updates existing observation results with newly gathered ones. + Function that updates existing raw observation results with newly gathered results. results_formatter - A function that formats the observation results. + Function that formats the raw observation results. additional_stratifications - A list of additional :class:`stratification ` - names by which to stratify. + List of additional :class:`Stratification ` + names by which to stratify this observation by. excluded_stratifications - A list of default :class:`stratification ` - names to remove from the observation. + List of default :class:`Stratification ` + names to remove from this observation. aggregator_sources - A list of population view columns to be used in the aggregator. + List of population view columns to be used in the `aggregator`. aggregator - A function that computes the quantity for the observation. + Function that computes the quantity for this observation. to_observe - A function that determines whether to perform an observation on this Event. + Function that determines whether to perform an observation on this Event. - Returns + Raises ------ - None + ValueError + If any required callable arguments are missing. 
""" self._check_for_required_callables(name, {"results_updater": results_updater}) self._manager.register_observation( @@ -243,19 +242,6 @@ def register_stratified_observation( to_observe=to_observe, ) - @staticmethod - def _check_for_required_callables( - observation_name: str, required_callables: Dict[str, Callable] - ) -> None: - missing = [] - for arg_name, callable in required_callables.items(): - if callable == _required_function_placeholder: - missing.append(arg_name) - if len(missing) > 0: - raise ValueError( - f"Observation '{observation_name}' is missing required callable(s): {missing}" - ) - def register_unstratified_observation( self, name: str, @@ -274,45 +260,36 @@ def register_unstratified_observation( ] = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Provide the results system all the information it needs to perform a - stratified observation. + """Registers an unstratified observation to the results system. Parameters ---------- name - String name for the observation. + Name of the observation. It will also be the name of the output results file + for this particular observation. pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. when - String name of the phase of a time-step the observation should happen. Valid values are: - `"time_step__prepare"`, `"time_step"`, `"time_step__cleanup"`, `"collect_metrics"`. + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". requires_columns - A list of the state table columns that are required by either the pop_filter or the aggregator. + List of the state table columns that are required by either the `pop_filter` or the `aggregator`. requires_values - A list of the value pipelines that are required by either the pop_filter or the aggregator. 
+ List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. results_gatherer - A function that gathers the latest observation results. + Function that gathers the latest observation results. results_updater - A function that updates existing observation results with newly gathered ones. + Function that updates existing raw observation results with newly gathered results. results_formatter - A function that formats the observation results. - additional_stratifications - A list of additional :class:`stratification ` - names by which to stratify. - excluded_stratifications - A list of default :class:`stratification ` - names to remove from the observation. - aggregator_sources - A list of population view columns to be used in the aggregator. - aggregator - A function that computes the quantity for the observation. + Function that formats the raw observation results. to_observe - A function that determines whether to perform an observation on this Event. + Function that determines whether to perform an observation on this Event. - Returns + Raises ------ - None + ValueError + If any required callable arguments are missing. """ required_callables = { "results_gatherer": results_gatherer, @@ -349,42 +326,40 @@ def register_adding_observation( aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Provide the results system all the information it needs to perform the observation. + """Registers an adding observation to the results system; that is, + one that adds/sums new results to existing result values. Note that an adding + observation is a specific type of stratified observation. Parameters ---------- name - String name for the observation. + Name of the observation. It will also be the name of the output results file + for this particular observation. 
pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. when - String name of the phase of a time-step the observation should happen. Valid values are: - `"time_step__prepare"`, `"time_step"`, `"time_step__cleanup"`, `"collect_metrics"`. + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". requires_columns - A list of the state table columns that are required by either the pop_filter or the aggregator. + List of the state table columns that are required by either the `pop_filter` or the `aggregator`. requires_values - A list of the value pipelines that are required by either the pop_filter or the aggregator. + List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. results_formatter - A function that formats the observation results. + Function that formats the raw observation results. additional_stratifications - A list of additional :class:`stratification ` - names by which to stratify. + List of additional :class:`Stratification ` + names by which to stratify this observation by. excluded_stratifications - A list of default :class:`stratification ` - names to remove from the observation. + List of default :class:`Stratification ` + names to remove from this observation. aggregator_sources - A list of population view columns to be used in the aggregator. + List of population view columns to be used in the `aggregator`. aggregator - A function that computes the quantity for the observation. + Function that computes the quantity for this observation. to_observe - A function that determines whether to perform an observation on this Event. - - Returns - ------ - None + Function that determines whether to perform an observation on this Event. 
""" - self._manager.register_observation( observation_type=AddingObservation, is_stratified=True, @@ -413,30 +388,29 @@ def register_concatenating_observation( ] = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Provide the results system all the information it needs to perform the observation. + """Registers a concatenating observation to the results system; that is, + one that concatenates new results to existing results. Note that a + concatenating observation is a specific type of unstratified observation. Parameters ---------- name - String name for the observation. + Name of the observation. It will also be the name of the output results file + for this particular observation. pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. when - String name of the phase of a time-step the observation should happen. Valid values are: - `"time_step__prepare"`, `"time_step"`, `"time_step__cleanup"`, `"collect_metrics"`. + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". requires_columns - A list of the state table columns that are required by either the pop_filter or the aggregator. + List of the state table columns that are required by either the `pop_filter` or the `aggregator`. requires_values - A list of the value pipelines that are required by either the pop_filter or the aggregator. + List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. results_formatter - A function that formats the observation results. + Function that formats the raw observation results. to_observe - A function that determines whether to perform an observation on this Event. - - Returns - ------ - None + Function that determines whether to perform an observation on this Event. 
""" included_columns = ["event_time"] + requires_columns + requires_values self._manager.register_observation( @@ -451,3 +425,17 @@ def register_concatenating_observation( included_columns=included_columns, to_observe=to_observe, ) + + @staticmethod + def _check_for_required_callables( + observation_name: str, required_callables: Dict[str, Callable] + ) -> None: + """Raises a ValueError if any required callable arguments are missing.""" + missing = [] + for arg_name, callable in required_callables.items(): + if callable == _required_function_placeholder: + missing.append(arg_name) + if len(missing) > 0: + raise ValueError( + f"Observation '{observation_name}' is missing required callable(s): {missing}" + ) diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index acc012c5b..631be1f59 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -1,8 +1,14 @@ +""" +====================== +Results System Manager +====================== +""" + from __future__ import annotations from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import pandas as pd @@ -10,6 +16,7 @@ from vivarium.framework.results.context import ResultsContext from vivarium.framework.values import Pipeline from vivarium.manager import Manager +from vivarium.types import ScalarValue if TYPE_CHECKING: from vivarium.framework.engine import Builder @@ -23,13 +30,9 @@ class SourceType(Enum): class ResultsManager(Manager): """Backend manager object for the results management system. - The :class:`ResultManager` actually performs the actions needed to - stratify and observe results. 
It contains the public methods used by the - :class:`ResultsInterface` to register stratifications and observations, - which provide it with lists of methods to apply in their respective areas. - It is able to record observations at any of the time-step sub-steps - (`time_step__prepare`, `time_step`, `time_step__cleanup`, and - `collect_metrics`). + This class contains the public methods used by the :class:`ResultsInterface ` + to register stratifications and observations as well as the :method:`get_results` + method used to retrieve formatted results by the :class:`ResultsContext `. """ CONFIGURATION_DEFAULTS = { @@ -48,16 +51,23 @@ def __init__(self) -> None: @property def name(self) -> str: - """The name of this ResultsManager.""" return self._name def get_results(self) -> Dict[str, pd.DataFrame]: """Return the measure-specific formatted results in a dictionary. - NOTE: self._results_context.observations is a list where each item is a dictionary - of the form {lifecycle_phase: {(pop_filter, stratifications): List[Observation]}}. + Notes + ----- + self._results_context.observations is a list where each item is a dictionary + of the form {lifecycle_phase: {(pop_filter, stratification_names): List[Observation]}}. We use a triple-nested for loop to iterate over only the list of Observations - (i.e. we do not need the lifecycle_phase, pop_filter, or stratifications). + (i.e. we do not need the lifecycle_phase, pop_filter, or stratification_names + for this method). + + Returns + ------- + Dict[str, pd.DataFrame] + A dictionary of formatted results for each measure. 
""" formatted = {} for observation_details in self._results_context.observations.values(): @@ -72,6 +82,7 @@ def get_results(self) -> Dict[str, pd.DataFrame]: # noinspection PyAttributeOutsideInit def setup(self, builder: "Builder") -> None: + """Set up the results manager.""" self._results_context.setup(builder) self.logger = builder.logging.get_logger(self.name) @@ -90,7 +101,7 @@ def setup(self, builder: "Builder") -> None: self.set_default_stratifications(builder) def on_post_setup(self, _: Event) -> None: - """Initialize results with 0s DataFrame' for each measure and all stratifications""" + """Initialize results for each measure.""" registered_stratifications = self._results_context.stratifications used_stratifications = set() @@ -124,22 +135,23 @@ def on_post_setup(self, _: Event) -> None: ) def on_time_step_prepare(self, event: Event) -> None: + """Define the listener callable for the time_step__prepare phase.""" self.gather_results("time_step__prepare", event) def on_time_step(self, event: Event) -> None: + """Define the listener callable for the time_step phase.""" self.gather_results("time_step", event) def on_time_step_cleanup(self, event: Event) -> None: + """Define the listener callable for the time_step__cleanup phase.""" self.gather_results("time_step__cleanup", event) def on_collect_metrics(self, event: Event) -> None: + """Define the listener callable for the collect_metrics phase.""" self.gather_results("collect_metrics", event) def gather_results(self, lifecycle_phase: str, event: Event) -> None: - """Update the existing results with new results. Any columns in the - results group that are not already in the existing results are initialized - with 0.0. 
- """ + """Update existing results with any new results.""" population = self._prepare_population(event) if population.empty: return @@ -156,6 +168,17 @@ def gather_results(self, lifecycle_phase: str, event: Event) -> None: ########################## def set_default_stratifications(self, builder: Builder) -> None: + """Set the default stratifications for the results context. + + This passes the default stratifications from the configuration to the + :class:`ResultsContext ` + :meth:`set_default_stratifications` method to be set. + + Parameters + ---------- + builder + The builder object for the simulation. + """ default_stratifications = builder.configuration.stratification.default self._results_context.set_default_stratifications(default_stratifications) @@ -164,39 +187,44 @@ def register_stratification( name: str, categories: List[str], excluded_categories: Optional[List[str]], - mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]], + mapper: Optional[ + Union[ + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[ScalarValue], str], + ] + ], is_vectorized: bool, requires_columns: List[str] = [], requires_values: List[str] = [], ) -> None: - """Manager-level stratification registration, including resources and the stratification itself. + """Manager-level stratification registration. Adds a stratification + to the :class:`ResultsContext ` + as well as the stratification's required resources to this manager. Parameters ---------- name - Name of the of the column created by the stratification. + Name of the stratification. categories - List of string values that the mapper is allowed to output. + Exhaustive list of all possible stratification values. excluded_categories - List of mapped string values to be excluded from results processing. + List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. 
mapper - A callable that emits values in `categories` given inputs from columns - and values in the `requires_columns` and `requires_values`, respectively. + A callable that maps the columns and value pipelines specified by the + `requires_columns` and `requires_values` arguments to the stratification + categories. It can either map the entire population or an individual + simulant. A simulation will fail if the `mapper` ever produces an invalid + value. is_vectorized - `True` if the mapper function expects a `DataFrame`, and `False` if it - expects a row of the `DataFrame` and should be used by calling :func:`df.apply`. + True if the `mapper` function will map the entire population, and False + if it will only map a single simulant. requires_columns - A list of the state table columns that already need to be present - and populated in the state table before the pipeline modifier - is called. + A list of the state table columns that are required by the `mapper` + to produce the stratification. requires_values - A list of the value pipelines that need to be properly sourced - before the pipeline modifier is called. - - Returns - ------ - None + A list of the value pipelines that are required by the `mapper` to + produce the stratification. """ self.logger.debug(f"Registering stratification {name}") target_columns = list(requires_columns) + list(requires_values) @@ -216,36 +244,33 @@ def register_binned_stratification( target_type: str, **cut_kwargs, ) -> None: - """Manager-level registration of a continuous `target` quantity to observe into bins in a `binned_column`. + """Manager-level registration of a continuous `target` quantity to observe + into bins in a `binned_column`. Parameters ---------- target - String name of the state table column or value pipeline used to stratify. + Name of the state table column or value pipeline to be binned. binned_column - String name of the column for the binned quantities. + Name of the (binned) stratification. 
bin_edges List of scalars defining the bin edges, passed to :meth: pandas.cut. - The length must equal the length of `labels` plus one. - Note that the bins are left edge inclusive, e.g. bin edges [1, 2, 3] - indicate groups [1, 2) and [2, 3). + The length must be equal to the length of `labels` plus 1. labels - List of string labels for bins. The length must equal to the length - of `bin_edges` minus one. + List of string labels for bins. The length must be equal to the length + of `bin_edges` minus 1. excluded_categories - List of mapped string values to be excluded from results processing. + List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. target_type - "column" or "value" + Type specification of the `target` to be binned. "column" if it's a + state table column or "value" if it's a value pipeline. **cut_kwargs Keyword arguments for :meth: pandas.cut. - - Returns - ------ - None """ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: + """Use pandas.cut to bin continuous values""" data = data.squeeze() if not isinstance(data, pd.Series): raise ValueError(f"Expected a Series, but got type {type(data)}.") @@ -259,7 +284,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: f"match the number of labels ({len(labels)})" ) - target_arg = "requires_columns" if target_type == "column" else "required_values" + target_arg = "requires_columns" if target_type == "column" else "requires_values" target_kwargs = {target_arg: [target]} self.register_stratification( @@ -281,10 +306,37 @@ def register_observation( requires_columns: List[str], requires_values: List[str], **kwargs, - ): + ) -> None: + """Manager-level observation registration. Adds an observation to the + :class:`ResultsContext ` + as well as the observation's required resources to this manager. 
+ + Parameters + ---------- + observation_type + Specific class type of observation to register. + is_stratified + True if the observation is a stratified type and False if not. + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. + when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + requires_columns + List of the state table columns that are required by either the `pop_filter` or the `aggregator`. + requires_values + List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. + **kwargs + Additional keyword arguments to be passed to the observation's constructor. + """ self.logger.debug(f"Registering observation {name}") if is_stratified: + # Resolve required stratifications and add to kwargs dictionary additional_stratifications = kwargs.get("additional_stratifications", []) excluded_stratifications = kwargs.get("excluded_stratifications", []) self._warn_check_stratifications( @@ -296,6 +348,7 @@ def register_observation( excluded_stratifications, ) kwargs["stratifications"] = stratifications + # Remove the unused kwargs before passing to the results context registration del kwargs["additional_stratifications"] del kwargs["excluded_stratifications"] @@ -320,6 +373,7 @@ def _get_stratifications( additional_stratifications: List[str] = [], excluded_stratifications: List[str] = [], ) -> Tuple[str, ...]: + """Resolve the stratifications required for the observation.""" stratifications = list( set( self._results_context.default_stratifications @@ -332,6 +386,7 @@ def _get_stratifications( return tuple(sorted(stratifications)) def _add_resources(self, target: List[str], target_type: SourceType) -> None: + """Add 
required resources to the manager's list of required columns and values.""" if len(target) == 0: return # do nothing on empty lists target_set = set(target) - {"event_time", "current_time", "event_step_size"} @@ -341,6 +396,7 @@ def _add_resources(self, target: List[str], target_type: SourceType) -> None: self._required_values.update([self.get_value(target) for target in target_set]) def _prepare_population(self, event: Event) -> pd.DataFrame: + """Prepare the population for results gathering.""" population = self.population_view.subview(list(self._required_columns)).get( event.index ) diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 915d73290..aaa8802eb 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -1,9 +1,15 @@ +""" +============ +Observations +============ +""" + from __future__ import annotations import itertools from abc import ABC from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Iterable, Optional, Sequence, Tuple, Union import pandas as pd from pandas.api.types import CategoricalDtype @@ -17,23 +23,45 @@ @dataclass class BaseObservation(ABC): - """An abstract base dataclass to be inherited by concrete observations. 
- This class includes the following attributes: - - `name`: name of the observation and is the measure it is observing - - `pop_filter`: a filter that is applied to the population before the observation is made - - `when`: the phase that the observation is registered to - - `results_initializer`: method that initializes the results - - `results_gatherer`: method that gathers the new observation results - - `results_updater`: method that updates the results with new observations - - `results_formatter`: method that formats the results - - `to_observe`: method that determines whether to observe an event + """An abstract base dataclass to be inherited by concrete observations. It includes + an :meth:`observe` method that determines whether to observe results for a given event. + + Attributes + ---------- + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. + when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + results_initializer + Method or function that initializes the raw observation results + prior to starting the simulation. This could return, for example, an empty + DataFrame or one with a complete set of stratifications as the index and + all values set to 0.0. + results_gatherer + Method or function that gathers the new observation results. + results_updater + Method or function that updates existing raw observation results with newly gathered results. + results_formatter + Method or function that formats the raw observation results. + stratifications + Optional tuple of column names for the observation to stratify by. + to_observe + Method or function that determines whether to perform an observation on this Event. 
""" name: str pop_filter: str when: str - results_initializer: Callable[..., pd.DataFrame] - results_gatherer: Callable[..., pd.DataFrame] + results_initializer: Callable[[Iterable[str], Iterable[Stratification]], pd.DataFrame] + results_gatherer: Union[ + Callable[[pd.DataFrame, Sequence[str]], pd.DataFrame], + Callable[[pd.DataFrame], pd.DataFrame], + ] results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame] stratifications: Optional[Tuple[str]] @@ -45,6 +73,7 @@ def observe( df: Union[pd.DataFrame, DataFrameGroupBy], stratifications: Optional[tuple[str, ...]], ) -> Optional[pd.DataFrame]: + """Determine whether to observe the given event and, if so, gather the results.""" if not self.to_observe(event): return None else: @@ -55,15 +84,30 @@ def observe( class UnstratifiedObservation(BaseObservation): - """Container class for managing unstratified observations. - This class includes the following attributes: - - `name`: name of the observation and is the measure it is observing - - `pop_filter`: a filter that is applied to the population before the observation is made - - `when`: the phase that the observation is registered to - - `results_gatherer`: method that gathers the new observation results - - `results_updater`: method that updates the results with new observations - - `results_formatter`: method that formats the results - - `to_observe`: method that determines whether to observe an event + """Concrete class for observing results that are not stratified. + + The parent class `stratifications` are set to None and the `results_initializer` + method is explicitly defined. + + Attributes + ---------- + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. 
+ when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + results_gatherer + Method or function that gathers the new observation results. + results_updater + Method or function that updates existing raw observation results with newly gathered results. + results_formatter + Method or function that formats the raw observation results. + to_observe + Method or function that determines whether to perform an observation on this Event. """ def __init__( @@ -80,7 +124,7 @@ def __init__( name=name, pop_filter=pop_filter, when=when, - results_initializer=self.initialize_results, + results_initializer=self.create_empty_df, results_gatherer=results_gatherer, results_updater=results_updater, results_formatter=results_formatter, @@ -89,7 +133,7 @@ def __init__( ) @staticmethod - def initialize_results( + def create_empty_df( requested_stratification_names: set[str], registered_stratifications: list[Stratification], ) -> pd.DataFrame: @@ -98,17 +142,36 @@ def initialize_results( class StratifiedObservation(BaseObservation): - """Container class for managing stratified observations. - This class includes the following attributes: - - `name`: name of the observation and is the measure it is observing - - `pop_filter`: a filter that is applied to the population before the observation is made - - `when`: the phase that the observation is registered to - - `results_updater`: method that updates the results with new observations - - `results_formatter`: method that formats the results - - `stratifications`: a tuple of columns for the observation to stratify by - - `aggregator_sources`: a list of the columns to observe - - `aggregator`: a method that aggregates the `aggregator_sources` - - `to_observe`: method that determines whether to observe an event + """Concrete class for observing stratified results. 
+ + The parent class `results_initializer` and `results_gatherer` methods are + explicitly defined and stratification-specific attributes `aggregator_sources` + and `aggregator` are added. + + Attributes + ---------- + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. + when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + results_updater + Method or function that updates existing raw observation results with newly gathered results. + results_formatter + Method or function that formats the raw observation results. + stratifications + Tuple of column names for the observation to stratify by. If empty, + the observation is aggregated over the entire population. + aggregator_sources + List of population view columns to be used in the `aggregator`. + aggregator + Method or function that computes the quantity for this observation. + to_observe + Method or function that determines whether to perform an observation on this Event. 
""" def __init__( @@ -127,8 +190,8 @@ def __init__( name=name, pop_filter=pop_filter, when=when, - results_initializer=self.initialize_results, - results_gatherer=self.results_gatherer, + results_initializer=self.create_expanded_df, + results_gatherer=self.get_complete_stratified_results, results_updater=results_updater, results_formatter=results_formatter, stratifications=stratifications, @@ -138,7 +201,7 @@ def __init__( self.aggregator = aggregator @staticmethod - def initialize_results( + def create_expanded_df( requested_stratification_names: set[str], registered_stratifications: list[Stratification], ) -> pd.DataFrame: @@ -173,11 +236,25 @@ def initialize_results( return df - def results_gatherer( + def get_complete_stratified_results( self, pop_groups: DataFrameGroupBy, stratifications: Tuple[str, ...], ) -> pd.DataFrame: + """Gather results for this observation. + + Parameters + ---------- + pop_groups + The population grouped by the stratifications. + stratifications + The stratifications to use for the observation. + + Returns + ------- + pd.DataFrame + The results of the observation. + """ df = self._aggregate(pop_groups, self.aggregator_sources, self.aggregator) df = self._format(df) df = self._expand_index(df) @@ -191,6 +268,9 @@ def _aggregate( aggregator_sources: Optional[list[str]], aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], ) -> Union[pd.Series[float], pd.DataFrame]: + """Apply the `aggregator` to the population groups and their + `aggregator_sources` columns. + """ aggregates = ( pop_groups[aggregator_sources].apply(aggregator).fillna(0.0) if aggregator_sources @@ -200,6 +280,9 @@ def _aggregate( @staticmethod def _format(aggregates: Union[pd.Series[float], pd.DataFrame]) -> pd.DataFrame: + """Convert the results to a pandas DataFrame if necessary and ensure the + results column name is 'value'. 
+ """ df = pd.DataFrame(aggregates) if isinstance(aggregates, pd.Series) else aggregates if df.shape[1] == 1: df.rename(columns={df.columns[0]: "value"}, inplace=True) @@ -207,6 +290,7 @@ def _format(aggregates: Union[pd.Series[float], pd.DataFrame]) -> pd.DataFrame: @staticmethod def _expand_index(aggregates: pd.DataFrame) -> pd.DataFrame: + """Include all stratifications in the results by filling missing values with 0.""" if isinstance(aggregates.index, pd.MultiIndex): full_idx = pd.MultiIndex.from_product(aggregates.index.levels) else: @@ -216,17 +300,32 @@ def _expand_index(aggregates: pd.DataFrame) -> pd.DataFrame: class AddingObservation(StratifiedObservation): - """Specific container class for managing stratified observations that add - new results to previous ones at each phase the class is registered to. - This class includes the following attributes: - - `name`: name of the observation and is the measure it is observing - - `pop_filter`: a filter that is applied to the population before the observation is made - - `when`: the phase that the observation is registered to - - `results_formatter`: method that formats the results - - `stratifications`: a tuple of columns for the observation to stratify by - - `aggregator_sources`: a list of the columns to observe - - `aggregator`: a method that aggregates the `aggregator_sources` - - `to_observe`: method that determines whether to observe an event + """Concrete class for observing additive and stratified results. + + The parent class `results_updater` method is explicitly defined and + stratification-specific attributes `aggregator_sources` and `aggregator` are added. + + Attributes + ---------- + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. 
+ when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + results_formatter + Method or function that formats the raw observation results. + stratifications + Optional tuple of column names for the observation to stratify by. + aggregator_sources + List of population view columns to be used in the `aggregator`. + aggregator + Method or function that computes the quantity for this observation. + to_observe + Method or function that determines whether to perform an observation on this Event. """ def __init__( @@ -256,6 +355,13 @@ def __init__( def add_results( existing_results: pd.DataFrame, new_observations: pd.DataFrame ) -> pd.DataFrame: + """Add newly-observed results to the existing results. + + Notes + ----- + If the new observations contain columns not present in the existing results, + the columns are added to the DataFrame and initialized with 0.0s. + """ updated_results = existing_results.copy() # Look for extra columns in the new_observations and initialize with 0. extra_cols = [ @@ -269,16 +375,28 @@ def add_results( class ConcatenatingObservation(UnstratifiedObservation): - """Specific container class for managing observations that concatenate - new results to previous ones at each phase the class is registered to. - Note that this class does not support stratifications. - This class includes the following attributes: - - `name`: name of the observation and is the measure it is observing - - `pop_filter`: a filter that is applied to the population before the observation is made - - `when`: the phase that the observation is registered to - - `included_columns`: the columns to include in the observation - - `results_formatter`: method that formats the results - - `to_observe`: method that determines whether to observe an event + """Concrete class for observing concatenating (and by extension, unstratified) results. 
+ + The parent class `results_gatherer` and `results_updater` methods are explicitly + defined and attribute `included_columns` is added. + + Attributes + ---------- + name + Name of the observation. It will also be the name of the output results file + for this particular observation. + pop_filter + A Pandas query filter string to filter the population down to the simulants who should + be considered for the observation. + when + String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". + included_columns + Columns to include in the observation + results_formatter + Method or function that formats the raw observation results. + to_observe + Method or function that determines whether to perform an observation on this Event. """ def __init__( @@ -294,20 +412,22 @@ def __init__( name=name, pop_filter=pop_filter, when=when, - results_gatherer=self.results_gatherer, + results_gatherer=self.get_results_of_interest, results_updater=self.concatenate_results, results_formatter=results_formatter, to_observe=to_observe, ) self.included_columns = included_columns + def get_results_of_interest(self, pop: pd.DataFrame) -> pd.DataFrame: + """Return the population with only the `included_columns`.""" + return pop[self.included_columns] + @staticmethod def concatenate_results( existing_results: pd.DataFrame, new_observations: pd.DataFrame ) -> pd.DataFrame: + """Concatenate the existing results with the new observations.""" if existing_results.empty: return new_observations return pd.concat([existing_results, new_observations], axis=0).reset_index(drop=True) - - def results_gatherer(self, pop: pd.DataFrame) -> pd.DataFrame: - return pop[self.included_columns] diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 417173e0b..e0ddab7b7 100644 --- a/src/vivarium/framework/results/observer.py +++ 
b/src/vivarium/framework/results/observer.py @@ -1,3 +1,9 @@ +""" +========= +Observers +========= +""" + from abc import ABC, abstractmethod from typing import Any, Dict @@ -7,10 +13,10 @@ class Observer(Component, ABC): """An abstract base class intended to be subclassed by observer components. - The primary purpose of this class is to provide attributes required by - the subclass `report` method. - Note that a `register_observation` method must be defined in the subclass. + Notes + ----- + A `register_observation` method must be defined in the subclass. """ def __init__(self) -> None: @@ -29,6 +35,7 @@ def configuration_defaults(self) -> Dict[str, Any]: } def get_configuration_name(self) -> str: + """Return the name of a concrete observer for use in the configuration""" return self.name.split("_observer")[0] @abstractmethod @@ -37,12 +44,13 @@ def register_observations(self, builder: Builder) -> None: pass def setup_component(self, builder: Builder) -> None: + """Set up the observer component.""" super().setup_component(builder) self.register_observations(builder) - self.get_formatter_attributes(builder) + self.set_results_dir(builder) - def get_formatter_attributes(self, builder: Builder) -> None: - """Define commonly-used attributes for reporting.""" + def set_results_dir(self, builder: Builder) -> None: + """Define the results directory from the configuration.""" self.results_dir = ( builder.configuration.to_dict() .get("output_data", {}) diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index ba8452763..c16ef6b66 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -1,45 +1,64 @@ +""" +================ +Stratifications +================ +""" + from __future__ import annotations from dataclasses import dataclass -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import 
pandas as pd from pandas.api.types import CategoricalDtype +from vivarium.types import ScalarValue + STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" @dataclass class Stratification: - """Class for stratifying observed quantities by specified characteristics + """Class for stratifying observed quantities by specified characteristics. Each Stratification represents a set of mutually exclusive and collectively exhaustive categories into which simulants can be assigned. - The `Stratification` class has six fields: `name`, `sources`, `mapper`, - `categories`, `excluded_categories`, and `is_vectorized`. The `name` is the - name of the column created by the mapper. The `sources` is a list of columns - in the extended state table that are the inputs to the mapper function. Simulants - will later be grouped by this column (or these columns) during stratification. - `categories` is the total set of values that the mapper can output. - `excluded_categories` are values that have been requested to be excluded (and - already removed) from `categories`. The `mapper` is the method that transforms the source - to the name column. The method produces an output column by calling the mapper on the source - columns. If the mapper is `None`, the default identity mapper is used. If - the mapper is not vectorized this is performed by using `pd.apply`. - Finally, `is_vectorized` is a boolean parameter that signifies whether - mapper function is applied to a single simulant (`False`) or to the whole - population (`True`). - `Stratification` also has a `__call__()` method. The method produces an output column by calling the mapper on the source columns. + + Attributes + ---------- + name + Name of the stratification. + sources + A list of the columns and values needed as input for the `mapper`. + categories + Exhaustive list of all possible stratification values. + excluded_categories + List of possible stratification values to exclude from results processing. 
+ If None (the default), will use exclusions as defined in the configuration. + mapper + A callable that maps the columns and value pipelines specified by the + `requires_columns` and `requires_values` arguments to the stratification + categories. It can either map the entire population or an individual + simulant. A simulation will fail if the `mapper` ever produces an invalid + value. + is_vectorized + True if the `mapper` function will map the entire population, and False + if it will only map a single simulant. """ name: str sources: List[str] categories: List[str] excluded_categories: List[str] - mapper: Optional[Callable[[Union[pd.Series[str], pd.DataFrame]], pd.Series[str]]] = None + mapper: Optional[ + Union[ + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[ScalarValue], str], + ] + ] is_vectorized: bool = False def __str__(self) -> str: @@ -49,11 +68,22 @@ def __str__(self) -> str: ) def __post_init__(self) -> None: + """Assign a default `mapper` if none was provided and check for non-empty + `categories` and `sources` otherwise. + + Raises + ------ + ValueError + - If no mapper is provided and the number of sources is not 1. + - If the categories argument is empty. + - If the sources argument is empty. + """ if self.mapper is None: if len(self.sources) != 1: raise ValueError( - f"No mapper provided for stratification {self.name} with " - f"{len(self.sources)} stratification sources." + f"No mapper but {len(self.sources)} stratification sources are " + f"provided for stratification {self.name}. The list of sources " + "must be of length 1 if no mapper is provided." ) self.mapper = self._default_mapper self.is_vectorized = True @@ -62,11 +92,26 @@ def __post_init__(self) -> None: if not self.sources: raise ValueError("The sources argument must be non-empty.") - def __call__(self, population: pd.DataFrame) -> pd.Series[str]: - """Apply the mapper to the population 'sources' columns and add the result - to the population. 
Any excluded categories (which have already been removed - from self.categories) will be converted to NaNs in the new column - and dropped later at the observation level. + def stratify(self, population: pd.DataFrame) -> pd.Series[str]: + """Apply the mapper to the population `sources` columns to create a new + pandas Series to be added to the population. Any excluded categories + (which have already been removed from self.categories) will be converted + to NaNs in the new column and dropped later at the observation level. + + Parameters + ---------- + population + A pandas DataFrame containing the data to be stratified. + + Returns + ------- + pd.Series[str] + A pandas Series containing the mapped values to be used for stratifying. + + Raises + ------ + ValueError + If the mapper returns any values not in `categories` or `excluded_categories`. """ if self.is_vectorized: mapped_column = self.mapper(population[self.sources]) @@ -91,6 +136,22 @@ def __call__(self, population: pd.DataFrame) -> pd.Series[str]: @staticmethod def _default_mapper(pop: pd.DataFrame) -> pd.Series[str]: + """Default stratification mapper that squeezes a DataFrame to a Series. + + Parameters + ---------- + pop + A pandas DataFrame containing the data to be stratified. + + Returns + ------- + pd.Series[str] + A pandas Series containing the data to be stratified. + + Notes + ----- + The input DataFrame is guaranteed to have a single column. 
+ """ return pop.squeeze(axis=1) diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index e29bd7564..e377d6ebd 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder from vivarium.framework.population import PopulationView - from vivarium.framework.time import Time + from vivarium.types import Time def _next_state( diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time.py index ce7ef9661..15da5e8cd 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -11,14 +11,14 @@ """ -from datetime import datetime, timedelta from functools import partial -from numbers import Number -from typing import TYPE_CHECKING, Callable, List, Union +from typing import TYPE_CHECKING, Callable, List import numpy as np import pandas as pd +from vivarium.types import NumberLike, Time, Timedelta + if TYPE_CHECKING: from vivarium.framework.engine import Builder from vivarium.framework.population.population_view import PopulationView @@ -28,10 +28,6 @@ from vivarium.framework.values import list_combiner from vivarium.manager import Manager -Time = Union[pd.Timestamp, datetime, Number] -Timedelta = Union[pd.Timedelta, timedelta, Number] -NumberLike = Union[np.ndarray, pd.Series, pd.DataFrame, Number] - class SimulationClock(Manager): """A base clock that includes global clock and a pandas series of clocks for each simulant""" diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index 679bac768..652a2a25f 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -14,18 +14,14 @@ """ from collections import defaultdict -from numbers import Number -from typing import Any, Callable, Iterable, List, Tuple, Union +from typing import Any, Callable, Iterable, List, Tuple -import numpy as np import pandas as pd from 
vivarium.exceptions import VivariumError from vivarium.framework.utilities import from_yearly from vivarium.manager import Manager - -# Supports standard algebraic operations with scalar values. -NumberLike = Union[np.ndarray, pd.Series, pd.DataFrame, Number] +from vivarium.types import NumberLike class DynamicValueError(VivariumError): diff --git a/src/vivarium/interface/interactive.py b/src/vivarium/interface/interactive.py index 9e5d67324..c1462a2b2 100644 --- a/src/vivarium/interface/interactive.py +++ b/src/vivarium/interface/interactive.py @@ -19,9 +19,9 @@ import pandas as pd from vivarium.framework.engine import SimulationContext -from vivarium.framework.time import Time, Timedelta from vivarium.framework.values import Pipeline from vivarium.interface.utilities import log_progress, run_from_ipython +from vivarium.types import Time, Timedelta class InteractiveContext(SimulationContext): diff --git a/src/vivarium/types.py b/src/vivarium/types.py new file mode 100644 index 000000000..404ce9e10 --- /dev/null +++ b/src/vivarium/types.py @@ -0,0 +1,12 @@ +from datetime import datetime, timedelta +from numbers import Number +from typing import Union + +import numpy as np +import pandas as pd + +ScalarValue = Union[Number, timedelta, datetime] +LookupTableData = Union[ScalarValue, pd.DataFrame, list[ScalarValue], tuple[ScalarValue]] +NumberLike = Union[np.ndarray, pd.Series, pd.DataFrame, Number] +Time = Union[pd.Timestamp, datetime, Number] +Timedelta = Union[pd.Timedelta, timedelta, Number] diff --git a/tests/framework/components/test_manager.py b/tests/framework/components/test_manager.py index 81ef509f4..e26e1039e 100644 --- a/tests/framework/components/test_manager.py +++ b/tests/framework/components/test_manager.py @@ -174,7 +174,7 @@ def nest(start, depth): def test_setup_components(mocker): builder = mocker.Mock() builder.configuration = {} - mocker.patch("vivarium.framework.results.observer.Observer.get_formatter_attributes") + 
mocker.patch("vivarium.framework.results.observer.Observer.set_results_dir") mock_a = MockComponentA("test_a") mock_b = MockComponentB("test_b") components = OrderedComponentSet(mock_a, mock_b) diff --git a/tests/framework/results/test_interface.py b/tests/framework/results/test_interface.py index fe5036ead..6261b455d 100644 --- a/tests/framework/results/test_interface.py +++ b/tests/framework/results/test_interface.py @@ -71,12 +71,18 @@ def _silly_mapper(): assert stratification.is_vectorized is False -def test_register_binned_stratification(mocker): +@pytest.mark.parametrize( + "target, target_type", [("some-column", "column"), ("some-value", "value")] +) +def test_register_binned_stratification_foo(target, target_type, mocker): mgr = ResultsManager() mgr.logger = logger builder = mocker.Mock() - mgr._results_context.setup(builder) + mocker.patch.object(builder, "value.get_value") + builder.value.get_value = MethodType(mock_get_value, builder) + mgr.setup(builder) + # mgr._results_context.setup(builder) # Check pre-registration stratifications and manager required columns/values assert len(mgr._results_context.stratifications) == 0 @@ -84,26 +90,30 @@ def test_register_binned_stratification(mocker): assert len(mgr._required_values) == 0 mgr.register_binned_stratification( - target="some-column-to-bin", + target=target, binned_column="new-binned-column", bin_edges=[1, 2, 3], labels=["1_to_2", "2_to_3"], excluded_categories=["2_to_3"], - target_type="column", + target_type=target_type, some_kwarg="some-kwarg", some_other_kwarg="some-other-kwarg", ) # Check that manager required columns/values have been updated - assert mgr._required_columns == {"tracked", "some-column-to-bin"} - assert len(mgr._required_values) == 0 + assert mgr._required_columns == ( + {"tracked", target} + if target_type == "column" + else {"tracked"} + ) + assert mgr._required_values == ({target} if target_type == "value" else set()) # Check stratification registration stratifications = 
mgr._results_context.stratifications assert len(stratifications) == 1 stratification = stratifications[0] assert stratification.name == "new-binned-column" - assert stratification.sources == ["some-column-to-bin"] + assert stratification.sources == [target] assert stratification.categories == ["1_to_2"] assert stratification.excluded_categories == ["2_to_3"] # Cannot access the mapper because it's in local scope, so check __repr__ diff --git a/tests/framework/results/test_observer.py b/tests/framework/results/test_observer.py index 0ad4d1154..b0929f7ea 100644 --- a/tests/framework/results/test_observer.py +++ b/tests/framework/results/test_observer.py @@ -35,7 +35,7 @@ def test_observer_instantiation(): (True, None), ], ) -def test_get_formatter_attributes(is_interactive, results_dir, mocker): +def test_set_results_dir(is_interactive, results_dir, mocker): builder = mocker.Mock() if is_interactive: builder.configuration = LayeredConfigTree() @@ -47,6 +47,6 @@ def test_get_formatter_attributes(is_interactive, results_dir, mocker): ) observer = TestObserver() - observer.get_formatter_attributes(builder) + observer.set_results_dir(builder) assert observer.results_dir == results_dir diff --git a/tests/framework/results/test_stratification.py b/tests/framework/results/test_stratification.py index 00e47f58d..42bf82d7a 100644 --- a/tests/framework/results/test_stratification.py +++ b/tests/framework/results/test_stratification.py @@ -45,7 +45,7 @@ def test_stratification(mapper, is_vectorized): mapper=mapper, is_vectorized=is_vectorized, ) - output = my_stratification(STUDENT_TABLE) + output = my_stratification.stratify(STUDENT_TABLE) assert output.eq(STUDENT_HOUSES).all() @@ -56,13 +56,19 @@ def test_stratification(mapper, is_vectorized): [], HOUSE_CATEGORIES, None, - f"No mapper provided for stratification {NAME} with 0 stratification sources.", + ( + f"No mapper but 0 stratification sources are provided for stratification {NAME}. 
" + "The list of sources must be of length 1 if no mapper is provided." + ), ), ( NAME_COLUMNS, HOUSE_CATEGORIES, None, - f"No mapper provided for stratification {NAME} with {len(NAME_COLUMNS)} stratification sources.", + ( + f"No mapper but {len(NAME_COLUMNS)} stratification sources are provided for stratification {NAME}. " + "The list of sources must be of length 1 if no mapper is provided." + ), ), ( [], @@ -143,7 +149,7 @@ def test_stratification_call_raises( NAME, sources, HOUSE_CATEGORIES, [], mapper, is_vectorized ) with pytest.raises(expected_exception, match=re.escape(error_match)): - my_stratification(STUDENT_TABLE) + my_stratification.stratify(STUDENT_TABLE) @pytest.mark.parametrize("default_stratifications", [["age", "sex"], ["age"], []]) diff --git a/tests/framework/test_state_machine.py b/tests/framework/test_state_machine.py index 3aa2f311a..687f268c2 100644 --- a/tests/framework/test_state_machine.py +++ b/tests/framework/test_state_machine.py @@ -6,7 +6,7 @@ from vivarium import Component, InteractiveContext from vivarium.framework.population import SimulantData from vivarium.framework.state_machine import Machine, State, Transition -from vivarium.framework.time import Time +from vivarium.types import Time def _population_fixture(column, initial_value):