From 2badecb4b259289b0889a7b528e1a071db651c63 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Mon, 19 Aug 2024 14:29:31 -0600 Subject: [PATCH] Fixes for doc building; minor docstring updates (#467) --- .gitignore | 3 + docs/nitpick-exceptions | 3 + .../framework/results/context.rst | 1 + .../api_reference/framework/results/index.rst | 11 +++ .../framework/results/interface.rst | 1 + .../framework/results/manager.rst | 1 + .../framework/results/observation.rst | 1 + .../framework/results/observer.rst | 1 + .../framework/results/stratification.rst | 1 + docs/source/concepts/results.rst | 9 +++ src/vivarium/framework/artifact/artifact.py | 4 +- src/vivarium/framework/artifact/hdf.py | 2 +- .../framework/lookup/interpolation.py | 4 +- .../framework/population/population_view.py | 2 +- .../framework/randomness/index_map.py | 6 +- src/vivarium/framework/results/context.py | 20 ++--- src/vivarium/framework/results/interface.py | 16 ++-- src/vivarium/framework/results/manager.py | 12 ++- src/vivarium/framework/results/observation.py | 78 ++++++++++--------- src/vivarium/framework/results/observer.py | 8 ++ .../framework/results/stratification.py | 71 ++++++++--------- tests/framework/results/helpers.py | 2 +- tests/framework/results/test_context.py | 2 +- 23 files changed, 144 insertions(+), 115 deletions(-) create mode 100644 docs/source/api_reference/framework/results/context.rst create mode 100644 docs/source/api_reference/framework/results/index.rst create mode 100644 docs/source/api_reference/framework/results/interface.rst create mode 100644 docs/source/api_reference/framework/results/manager.rst create mode 100644 docs/source/api_reference/framework/results/observation.rst create mode 100644 docs/source/api_reference/framework/results/observer.rst create mode 100644 docs/source/api_reference/framework/results/stratification.rst create mode 100644 docs/source/concepts/results.rst diff --git a/.gitignore b/.gitignore index 26ffd0b77..d33df0a2d 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,6 @@ notebooks/ # Version file src/vivarium/_version.py + +# macOS +*.DS_Store diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 32abf70dd..ee1218c39 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -6,6 +6,7 @@ py:class pandas._libs.tslibs.timedeltas.Timedelta py:class pandas.core.frame.DataFrame py:class pandas.core.series.Series py:class pandas.core.generic.PandasObject +py:class pandas.core.groupby.generic.DataFrameGroupBy #scipy py:class scipy.stats._distn_infrastructure.rv_continuous @@ -21,6 +22,8 @@ py:class Time py:class vivarium.framework.time.Time py:class Timedelta py:class vivarium.framework.time.Timedelta +py:exc ResultsConfigurationError +py:exc vivarium.framework.results.exceptions.ResultsConfigurationError # layered_config_tree py:class layered_config_tree.main.LayeredConfigTree diff --git a/docs/source/api_reference/framework/results/context.rst b/docs/source/api_reference/framework/results/context.rst new file mode 100644 index 000000000..db9e3a7ef --- /dev/null +++ b/docs/source/api_reference/framework/results/context.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.context \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/index.rst b/docs/source/api_reference/framework/results/index.rst new file mode 100644 index 000000000..ff60e2407 --- /dev/null +++ b/docs/source/api_reference/framework/results/index.rst @@ -0,0 +1,11 @@ +================== +Results Processing +================== + +.. automodule:: vivarium.framework.results + +.. toctree:: + :maxdepth: 1 + :glob: + + * \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/interface.rst b/docs/source/api_reference/framework/results/interface.rst new file mode 100644 index 000000000..291c89916 --- /dev/null +++ b/docs/source/api_reference/framework/results/interface.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.interface \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/manager.rst b/docs/source/api_reference/framework/results/manager.rst new file mode 100644 index 000000000..40bb0a84b --- /dev/null +++ b/docs/source/api_reference/framework/results/manager.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.manager \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/observation.rst b/docs/source/api_reference/framework/results/observation.rst new file mode 100644 index 000000000..8498e429d --- /dev/null +++ b/docs/source/api_reference/framework/results/observation.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.observation \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/observer.rst b/docs/source/api_reference/framework/results/observer.rst new file mode 100644 index 000000000..ebea7dd92 --- /dev/null +++ b/docs/source/api_reference/framework/results/observer.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.observer \ No newline at end of file diff --git a/docs/source/api_reference/framework/results/stratification.rst b/docs/source/api_reference/framework/results/stratification.rst new file mode 100644 index 000000000..b3c13ed1c --- /dev/null +++ b/docs/source/api_reference/framework/results/stratification.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.results.stratification \ No newline at end of file diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst new file mode 100644 index 000000000..990e78f04 --- /dev/null +++ b/docs/source/concepts/results.rst @@ -0,0 +1,9 @@ +.. _results_concept: + +================== +Simulation Results +================== + +.. todo:: + + Everything here. \ No newline at end of file diff --git a/src/vivarium/framework/artifact/artifact.py b/src/vivarium/framework/artifact/artifact.py index 859c9fafe..e98f96981 100644 --- a/src/vivarium/framework/artifact/artifact.py +++ b/src/vivarium/framework/artifact/artifact.py @@ -37,7 +37,7 @@ def __init__(self, path: Union[str, Path], filter_terms: List[str] = None): The path to the artifact file. filter_terms A set of terms suitable for usage with the ``where`` kwarg - for :func:`pd.read_hdf`. + for :func:`pandas.read_hdf`. """ self._path = Path(path) @@ -76,7 +76,7 @@ def create_hdf_with_keyspace(path: Path): "Attempting to construct an Artifact from a malformed existing file. " "This can occur when constructing an Artifact from an existing file when " "the existing file was generated by some other hdf writing mechanism " - "(e.g. pd.to_hdf) rather than generating the the file using this class " + "(e.g. pandas.to_hdf) rather than generating the the file using this class " "and a non-existent or empty hdf file." ) if not keys: diff --git a/src/vivarium/framework/artifact/hdf.py b/src/vivarium/framework/artifact/hdf.py index 35bd098ea..c57578c40 100644 --- a/src/vivarium/framework/artifact/hdf.py +++ b/src/vivarium/framework/artifact/hdf.py @@ -404,7 +404,7 @@ def _get_valid_filter_terms(filter_terms, colnames): ---------- filter_terms A list of terms formatted so as to be used in the `where` argument of - :func:`pd.read_hdf`. + :func:`pandas.read_hdf`. colnames : A list of column names present in the data that will be filtered. diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index ce3c6eb66..2898aae23 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -99,7 +99,7 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Returns ------- - pd.DataFrame + pandas.DataFrame A table with the interpolated values for the given interpolants. """ @@ -316,7 +316,7 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Returns ------- - pd.DataFrame + pandas.DataFrame A table with the interpolated values for the given interpolants. """ diff --git a/src/vivarium/framework/population/population_view.py b/src/vivarium/framework/population/population_view.py index c6794ea39..cbcdfcdf6 100644 --- a/src/vivarium/framework/population/population_view.py +++ b/src/vivarium/framework/population/population_view.py @@ -475,7 +475,7 @@ def _update_column_and_ensure_dtype( Returns ------- - pd.Series + pandas.Series The column with the provided update applied """ diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index f51627148..e7645b1cd 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -73,7 +73,7 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul Returns ------- - Tuple[pd.MultiIndex, pd.MultiIndex] + Tuple[pandas.MultiIndex, pandas.MultiIndex] A tuple of the new mapping index and the final mapping index. Both are pandas indices with a level for the index assigned by the population system and additional levels for the key columns associated with the simulant index. The @@ -108,7 +108,7 @@ def _build_final_mapping( Returns ------- - pd.Series + pandas.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. @@ -140,7 +140,7 @@ def _resolve_collisions( Returns ------- - pd.Series + pandas.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 17776b84f..f7d23d0a5 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -4,13 +4,11 @@ =============== """ -from __future__ import annotations - from collections import defaultdict -from typing import Any, Callable, Generator, List, Optional, Tuple, Type, Union +from typing import Callable, Generator, List, Optional, Tuple, Type, Union import pandas as pd -from pandas.core.groupby import DataFrameGroupBy +from pandas.core.groupby.generic import DataFrameGroupBy from vivarium.framework.engine import Builder from vivarium.framework.event import Event @@ -29,8 +27,8 @@ class ResultsContext: This context object is wholly contained by :class:`ResultsManager `. Stratifications and observations can be added to the context through the manager via the - :meth:`vivarium.framework.results.context.ResultsContext.add_stratification` and - :meth:`vivarium.framework.results.context.ResultsContext.register_observation` methods, respectively. + :meth:`add_stratification ` and + :meth:`register_observation ` methods, respectively. Attributes ---------- @@ -101,7 +99,7 @@ def add_stratification( excluded_categories: Optional[List[str]], mapper: Optional[ Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], Callable[[ScalarValue], str], ] ], @@ -132,9 +130,11 @@ def add_stratification( Raises ------ ValueError - - If the stratification `name` is already used. - - If there are duplicate `categories`. - - If any `excluded_categories` are not in `categories`. + If the stratification `name` is already used. + ValueError + If there are duplicate `categories`. + ValueError + If any `excluded_categories` are not in `categories`. """ already_used = [ stratification diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 88fac452b..25eb9264d 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -1,16 +1,14 @@ """ -========================== -Vivarium Results Interface -========================== +================= +Results Interface +================= This module provides a :class:`ResultsInterface ` class with methods to register stratifications and results producers (referred to as "observations") to a simulation. """ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union import pandas as pd @@ -74,7 +72,7 @@ def register_stratification( excluded_categories: Optional[List[str]] = None, mapper: Optional[ Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], Callable[[ScalarValue], str], ] ] = None, @@ -182,7 +180,7 @@ def register_stratified_observation( additional_stratifications: List[str] = [], excluded_stratifications: List[str] = [], aggregator_sources: Optional[List[str]] = None, - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, + aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a stratified observation to the results system. @@ -323,7 +321,7 @@ def register_adding_observation( additional_stratifications: List[str] = [], excluded_stratifications: List[str] = [], aggregator_sources: Optional[List[str]] = None, - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]] = len, + aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an adding observation to the results system; that is, diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 631be1f59..25807899c 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -4,11 +4,9 @@ ====================== """ -from __future__ import annotations - from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import pandas as pd @@ -31,7 +29,7 @@ class ResultsManager(Manager): """Backend manager object for the results management system. This class contains the public methods used by the :class:`ResultsInterface ` - to register stratifications and observations as well as the :method:`get_results` + to register stratifications and observations as well as the :meth:`get_results ` method used to retrieve formatted results by the :class:`ResultsContext `. """ @@ -66,7 +64,7 @@ def get_results(self) -> Dict[str, pd.DataFrame]: Returns ------- - Dict[str, pd.DataFrame] + Dict[str, pandas.DataFrame] A dictionary of formatted results for each measure. """ formatted = {} @@ -167,7 +165,7 @@ def gather_results(self, lifecycle_phase: str, event: Event) -> None: # Stratification methods # ########################## - def set_default_stratifications(self, builder: Builder) -> None: + def set_default_stratifications(self, builder: "Builder") -> None: """Set the default stratifications for the results context. This passes the default stratifications from the configuration to the @@ -189,7 +187,7 @@ def register_stratification( excluded_categories: Optional[List[str]], mapper: Optional[ Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], Callable[[ScalarValue], str], ] ], diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index aaa8802eb..d31062167 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -2,9 +2,20 @@ ============ Observations ============ -""" -from __future__ import annotations +An observation is a class object that records simulation results; they are responsible +for initializing, gathering, updating, and formatting results. + +The provided :class:`BaseObservation` class is an abstract base class that should +be subclassed by concrete observations. While there are no required abstract methods +to define when subclassing, the class does provide common attributes as well +as an `observe` method that determines whether to observe results for a given event. + +At the highest level, an observation can be categorized as either an +:class:`UnstratifiedObservation` or a :class:`StratifiedObservation`. More specialized +implementations of these classes involve defining the various methods +provided as attributes to the parent class. +""" import itertools from abc import ABC @@ -13,7 +24,7 @@ import pandas as pd from pandas.api.types import CategoricalDtype -from pandas.core.groupby import DataFrameGroupBy +from pandas.core.groupby.generic import DataFrameGroupBy from vivarium.framework.event import Event from vivarium.framework.results.stratification import Stratification @@ -23,49 +34,40 @@ @dataclass class BaseObservation(ABC): - """An abstract base dataclass to be inherited by concrete observations. It includes - an :meth:`observe` method that determines whether to observe results for a given event. + """An abstract base dataclass to be inherited by concrete observations. - Attributes - ---------- - name - Name of the observation. It will also be the name of the output results file - for this particular observation. - pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. - when - String name of the lifecycle phase the observation should happen. Valid values are: - "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - results_initializer - Method or function that initializes the raw observation results - prior to starting the simulation. This could return, for example, an empty - DataFrame or one with a complete set of stratifications as the index and - all values set to 0.0. - results_gatherer - Method or function that gathers the new observation results. - results_updater - Method or function that updates existing raw observation results with newly gathered results. - results_formatter - Method or function that formats the raw observation results. - stratifications - Optional tuple of column names for the observation to stratify by. - to_observe - Method or function that determines whether to perform an observation on this Event. + This class includes an :meth:`observe ` method that determines whether + to observe results for a given event. """ name: str + """Name of the observation. It will also be the name of the output results file + for this particular observation.""" pop_filter: str + """A Pandas query filter string to filter the population down to the simulants + who should be considered for the observation.""" when: str + """String name of the lifecycle phase the observation should happen. Valid values are: + "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".""" results_initializer: Callable[[Iterable[str], Iterable[Stratification]], pd.DataFrame] + """Method or function that initializes the raw observation results + prior to starting the simulation. This could return, for example, an empty + DataFrame or one with a complete set of stratifications as the index and + all values set to 0.0.""" results_gatherer: Union[ Callable[[pd.DataFrame, Sequence[str]], pd.DataFrame], Callable[[pd.DataFrame], pd.DataFrame], ] + """Method or function that gathers the new observation results.""" results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] + """Method or function that updates existing raw observation results with newly + gathered results.""" results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame] + """Method or function that formats the raw observation results.""" stratifications: Optional[Tuple[str]] + """Optional tuple of column names for the observation to stratify by.""" to_observe: Callable[[Event], bool] + """Method or function that determines whether to perform an observation on this Event.""" def observe( self, @@ -73,7 +75,7 @@ def observe( df: Union[pd.DataFrame, DataFrameGroupBy], stratifications: Optional[tuple[str, ...]], ) -> Optional[pd.DataFrame]: - """Determine whether to observe the given event and, if so, gather the results.""" + # """Determine whether to observe the given event and, if so, gather the results.""" if not self.to_observe(event): return None else: @@ -183,7 +185,7 @@ def __init__( results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], stratifications: Tuple[str, ...], aggregator_sources: Optional[list[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], + aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]], to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( @@ -252,7 +254,7 @@ def get_complete_stratified_results( Returns ------- - pd.DataFrame + pandas.DataFrame The results of the observation. """ df = self._aggregate(pop_groups, self.aggregator_sources, self.aggregator) @@ -266,8 +268,8 @@ def get_complete_stratified_results( def _aggregate( pop_groups: DataFrameGroupBy, aggregator_sources: Optional[list[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], - ) -> Union[pd.Series[float], pd.DataFrame]: + aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]], + ) -> Union[pd.Series, pd.DataFrame]: """Apply the `aggregator` to the population groups and their `aggregator_sources` columns. """ @@ -279,7 +281,7 @@ def _aggregate( return aggregates @staticmethod - def _format(aggregates: Union[pd.Series[float], pd.DataFrame]) -> pd.DataFrame: + def _format(aggregates: Union[pd.Series, pd.DataFrame]) -> pd.DataFrame: """Convert the results to a pandas DataFrame if necessary and ensure the results column name is 'value'. """ @@ -336,7 +338,7 @@ def __init__( results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], stratifications: Tuple[str, ...], aggregator_sources: Optional[list[str]], - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series[float]]], + aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]], to_observe: Callable[[Event], bool] = lambda event: True, ): super().__init__( diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index e0ddab7b7..5106b087b 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -2,6 +2,14 @@ ========= Observers ========= + +An observer is a component that is responsible for registering +:class:`observations ` +to the simulation. + +The provided :class:`Observer` class is an abstract base class that should be subclassed +by concrete observers. Each concrete observer is required to implement a +`register_observations` method that registers all required observations. """ from abc import ABC, abstractmethod diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index c16ef6b66..765a047ac 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -1,13 +1,11 @@ """ -================ +=============== Stratifications -================ +=============== """ -from __future__ import annotations - from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Union +from typing import Callable, List, Optional, Union import pandas as pd from pandas.api.types import CategoricalDtype @@ -24,42 +22,33 @@ class Stratification: Each Stratification represents a set of mutually exclusive and collectively exhaustive categories into which simulants can be assigned. - `Stratification` also has a `__call__()` method. The method produces an + This class includes a :meth:`stratify ` method that produces an output column by calling the mapper on the source columns. - - Attributes - ---------- - name - Name of the stratification. - sources - A list of the columns and values needed as input for the `mapper`. - categories - Exhaustive list of all possible stratification values. - excluded_categories - List of possible stratification values to exclude from results processing. - If None (the default), will use exclusions as defined in the configuration. - mapper - A callable that maps the columns and value pipelines specified by the - `requires_columns` and `requires_values` arguments to the stratification - categories. It can either map the entire population or an individual - simulant. A simulation will fail if the `mapper` ever produces an invalid - value. - is_vectorized - True if the `mapper` function will map the entire population, and False - if it will only map a single simulant. """ name: str + """Name of the stratification.""" sources: List[str] + """A list of the columns and values needed as input for the `mapper`.""" categories: List[str] + """Exhaustive list of all possible stratification values.""" excluded_categories: List[str] + """List of possible stratification values to exclude from results processing. + If None (the default), will use exclusions as defined in the configuration.""" mapper: Optional[ Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series[str]], + Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], Callable[[ScalarValue], str], ] ] + """A callable that maps the columns and value pipelines specified by the + `requires_columns` and `requires_values` arguments to the stratification + categories. It can either map the entire population or an individual + simulant. A simulation will fail if the `mapper` ever produces an invalid + value.""" is_vectorized: bool = False + """True if the `mapper` function will map the entire population, and False + if it will only map a single simulant.""" def __str__(self) -> str: return ( @@ -74,9 +63,11 @@ def __post_init__(self) -> None: Raises ------ ValueError - - If no mapper is provided and the number of sources is not 1. - - If the categories argument is empty. - - If the sources argument is empty. + If no mapper is provided and the number of sources is not 1. + ValueError + If the categories argument is empty. + ValueError + If the sources argument is empty. """ if self.mapper is None: if len(self.sources) != 1: @@ -92,21 +83,21 @@ def __post_init__(self) -> None: if not self.sources: raise ValueError("The sources argument must be non-empty.") - def stratify(self, population: pd.DataFrame) -> pd.Series[str]: + def stratify(self, population: pd.DataFrame) -> pd.Series: """Apply the mapper to the population `sources` columns to create a new - pandas Series to be added to the population. Any excluded categories + Series to be added to the population. Any excluded categories (which have already been removed from self.categories) will be converted to NaNs in the new column and dropped later at the observation level. Parameters ---------- population - A pandas DataFrame containing the data to be stratified. + A DataFrame containing the data to be stratified. Returns ------- - pd.Series[str] - A pandas Series containing the mapped values to be used for stratifying. + pandas.Series + A Series containing the mapped values to be used for stratifying. Raises ------ @@ -135,18 +126,18 @@ def stratify(self, population: pd.DataFrame) -> pd.Series[str]: return mapped_column @staticmethod - def _default_mapper(pop: pd.DataFrame) -> pd.Series[str]: + def _default_mapper(pop: pd.DataFrame) -> pd.Series: """Default stratification mapper that squeezes a DataFrame to a Series. Parameters ---------- pop - A pandas DataFrame containing the data to be stratified. + A DataFrame containing the data to be stratified. Returns ------- - pd.Series[str] - A pandas Series containing the data to be stratified. + pandas.Series + A Series containing the data to be stratified. Notes ----- diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index 6ff72e54d..c088587b2 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -312,7 +312,7 @@ def sorting_hat_bad_mapping(simulant_row: pd.Series) -> str: def verify_stratification_added( stratification_list, name, sources, categories, excluded_categories, mapper, is_vectorized ): - """Verify that a :class: `vivarium.framework.results.stratification.Stratification` is in `stratification_list`""" + """Verify that a Stratification object is in `stratification_list`""" matching_stratification_found = False for stratification in stratification_list: # noqa # big equality check diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 794d957a6..e006f464a 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -8,7 +8,7 @@ import pytest from layered_config_tree import LayeredConfigTree from loguru import logger -from pandas.core.groupby import DataFrameGroupBy +from pandas.core.groupby.generic import DataFrameGroupBy from tests.framework.results.helpers import ( BASE_POPULATION,