From c12782fe28ab5fcebda9d443cfafc425afbbbbe9 Mon Sep 17 00:00:00 2001 From: patricktnast <130876799+patricktnast@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:55:47 -0700 Subject: [PATCH] Mypy Stratifications.py (#487) * make dataclass a normal class * fix tests * mostly just ignore the issue * remove protocol * revert change * rename * finish rename * update docstring * revert to adding user defined function * add nitpick exceptions * private method * missed a reference * add comment * add more nitpicks * try not using the pandas prefix * change back to dataclass * add back docstring * add docstring * Delete vivarium.code-workspace * change names * change vector to vectorized --- docs/nitpick-exceptions | 3 + pyproject.toml | 1 - .../framework/results/stratification.py | 81 ++++++++++--------- .../framework/results/test_stratification.py | 4 +- 4 files changed, 47 insertions(+), 42 deletions(-) diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 5b4780ed3..8119c79b6 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -1,4 +1,5 @@ # pandas +py:class CategoricalDtype py:class pandas.core.indexes.base.Index py:class pandas.core.indexes.multi.MultiIndex py:class pandas._libs.tslibs.timestamps.Timestamp @@ -26,6 +27,8 @@ py:class ClockTime py:class Time py:class ClockStepSize py:class Timedelta +py:class VectorMapper +py:class ScalarMapper py:exc ResultsConfigurationError py:exc vivarium.framework.results.exceptions.ResultsConfigurationError diff --git a/pyproject.toml b/pyproject.toml index 319eeb02c..a3c101e2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,6 @@ exclude = [ 'src/vivarium/framework/results/manager.py', 'src/vivarium/framework/results/observation.py', 'src/vivarium/framework/results/observer.py', - 'src/vivarium/framework/results/stratification.py', 'src/vivarium/framework/state_machine.py', 'src/vivarium/framework/time.py', 'src/vivarium/framework/values.py', diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 463dcae11..7be52813b 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -1,49 +1,43 @@ -# mypy: ignore-errors """ =============== Stratifications =============== """ +from __future__ import annotations from dataclasses import dataclass -from typing import Callable, List, Optional, Union +from typing import Any, Callable import pandas as pd from pandas.api.types import CategoricalDtype -from vivarium.types import ScalarValue - STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" +# TODO: Parameterizing pandas objects fails below python 3.12 +VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] +ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] + @dataclass class Stratification: """Class for stratifying observed quantities by specified characteristics. - Each Stratification represents a set of mutually exclusive and collectively exhaustive categories into which simulants can be assigned. - This class includes a :meth:`stratify ` method that produces an output column by calling the mapper on the source columns. - """ name: str """Name of the stratification.""" - sources: List[str] + sources: list[str] """A list of the columns and values needed as input for the `mapper`.""" - categories: List[str] + categories: list[str] """Exhaustive list of all possible stratification values.""" - excluded_categories: List[str] + excluded_categories: list[str] """List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration.""" - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ] + mapper: VectorMapper | ScalarMapper | None """A callable that maps the columns and value pipelines specified by the `requires_columns` and `requires_values` arguments to the stratification categories. It can either map the entire population or an individual @@ -56,13 +50,12 @@ class Stratification: def __str__(self) -> str: return ( f"Stratification '{self.name}' with sources {self.sources}, " - f"categories {self.categories}, and mapper {self.mapper.__name__}" + f"categories {self.categories}, and mapper {getattr(self.mapper, '__name__', repr(self.mapper))}" ) def __post_init__(self) -> None: """Assign a default `mapper` if none was provided and check for non-empty `categories` and `sources` otherwise. - Raises ------ ValueError @@ -72,21 +65,13 @@ def __post_init__(self) -> None: ValueError If the sources argument is empty. """ - if self.mapper is None: - if len(self.sources) != 1: - raise ValueError( - f"No mapper but {len(self.sources)} stratification sources are " - f"provided for stratification {self.name}. The list of sources " - "must be of length 1 if no mapper is provided." - ) - self.mapper = self._default_mapper - self.is_vectorized = True + self.vectorized_mapper = self._get_vectorized_mapper(self.mapper, self.is_vectorized) if not self.categories: raise ValueError("The categories argument must be non-empty.") if not self.sources: raise ValueError("The sources argument must be non-empty.") - def stratify(self, population: pd.DataFrame) -> pd.Series: + def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]: """Apply the `mapper` to the population `sources` columns to create a new Series to be added to the population. @@ -108,29 +93,46 @@ def stratify(self, population: pd.DataFrame) -> pd.Series: ValueError If the mapper returns any values not in `categories` or `excluded_categories`. """ - if self.is_vectorized: - mapped_column = self.mapper(population[self.sources]) - else: - mapped_column = population[self.sources].apply(self.mapper, axis=1) + mapped_column = self.vectorized_mapper(population[self.sources]) unknown_categories = set(mapped_column) - set( self.categories + self.excluded_categories ) # Reduce all nans to a single one - unknown_categories = [cat for cat in unknown_categories if not pd.isna(cat)] + unknown_categories = {cat for cat in unknown_categories if not pd.isna(cat)} if mapped_column.isna().any(): - unknown_categories.append(mapped_column[mapped_column.isna()].iat[0]) + unknown_categories.add(mapped_column[mapped_column.isna()].iat[0]) if unknown_categories: raise ValueError(f"Invalid values mapped to {self.name}: {unknown_categories}") # Convert the dtype to the allowed categories. Note that this will # result in Nans for any values in excluded_categories. - mapped_column = mapped_column.astype( + return mapped_column.astype( CategoricalDtype(categories=self.categories, ordered=True) ) - return mapped_column + + def _get_vectorized_mapper( + self, + user_provided_mapper: VectorMapper | ScalarMapper | None, + is_vectorized: bool, + ) -> VectorMapper: + """ + Choose a VectorMapper based on the inputted callable mapper. + """ + if user_provided_mapper is None: + if len(self.sources) != 1: + raise ValueError( + f"No mapper but {len(self.sources)} stratification sources are " + f"provided for stratification {self.name}. The list of sources " + "must be of length 1 if no mapper is provided." + ) + return self._default_mapper + elif is_vectorized: + return user_provided_mapper # type: ignore [return-value] + else: + return lambda population: population.apply(user_provided_mapper, axis=1) @staticmethod - def _default_mapper(pop: pd.DataFrame) -> pd.Series: + def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]: """Default stratification mapper that squeezes a DataFrame to a Series. Parameters @@ -144,9 +146,10 @@ def _default_mapper(pop: pd.DataFrame) -> pd.Series: Notes ----- - The input DataFrame is guaranteeed to have a single column. + The input DataFrame is guaranteed to have a single column. """ - return pop.squeeze(axis=1) + squeezed_pop: pd.Series[Any] = pop.squeeze(axis=1) + return squeezed_pop def get_mapped_col_name(col_name: str) -> str: diff --git a/tests/framework/results/test_stratification.py b/tests/framework/results/test_stratification.py index 42bf82d7a..a63dd26be 100644 --- a/tests/framework/results/test_stratification.py +++ b/tests/framework/results/test_stratification.py @@ -103,7 +103,7 @@ def test_stratification_init_raises(sources, categories, mapper, msg_match): sorting_hat_bad_mapping, False, ValueError, - "Invalid values mapped to hogwarts_house: ['pancakes']", + "Invalid values mapped to hogwarts_house: {'pancakes'}", ), ( ["middle_initial"], @@ -131,7 +131,7 @@ def test_stratification_init_raises(sources, categories, mapper, msg_match): lambda df: pd.Series(np.nan, index=df.index), True, ValueError, - f"Invalid values mapped to hogwarts_house: [{np.nan}]", + f"Invalid values mapped to hogwarts_house: {{{np.nan}}}", ), ], ids=[