Skip to content

Commit

Permalink
Mypy Stratifications.py (#487)
Browse files Browse the repository at this point in the history
* make dataclass a normal class

* fix tests

* mostly just ignore the issue

* remove protocol

* revert change

* rename

* finish rename

* update docstring

* revert to adding user defined function

* add nitpick exceptions

* private method

* missed a reference

* add comment

* add more nitpicks

* try not using the pandas prefix

* change back to dataclass

* add back docstring

* add docstring

* Delete vivarium.code-workspace

* change names

* change vector to vectorized
  • Loading branch information
patricktnast committed Sep 25, 2024
1 parent ede33e7 commit c12782f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 42 deletions.
3 changes: 3 additions & 0 deletions docs/nitpick-exceptions
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pandas
py:class CategoricalDtype
py:class pandas.core.indexes.base.Index
py:class pandas.core.indexes.multi.MultiIndex
py:class pandas._libs.tslibs.timestamps.Timestamp
Expand Down Expand Up @@ -26,6 +27,8 @@ py:class ClockTime
py:class Time
py:class ClockStepSize
py:class Timedelta
py:class VectorMapper
py:class ScalarMapper
py:exc ResultsConfigurationError
py:exc vivarium.framework.results.exceptions.ResultsConfigurationError

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ exclude = [
'src/vivarium/framework/results/manager.py',
'src/vivarium/framework/results/observation.py',
'src/vivarium/framework/results/observer.py',
'src/vivarium/framework/results/stratification.py',
'src/vivarium/framework/state_machine.py',
'src/vivarium/framework/time.py',
'src/vivarium/framework/values.py',
Expand Down
81 changes: 42 additions & 39 deletions src/vivarium/framework/results/stratification.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,43 @@
# mypy: ignore-errors
"""
===============
Stratifications
===============
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Union
from typing import Any, Callable

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.types import ScalarValue

STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"

# TODO: Parameterizing pandas objects fails below python 3.12
VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]


@dataclass
class Stratification:
"""Class for stratifying observed quantities by specified characteristics.
Each Stratification represents a set of mutually exclusive and collectively
exhaustive categories into which simulants can be assigned.
This class includes a :meth:`stratify <stratify>` method that produces an
output column by calling the mapper on the source columns.
"""

name: str
"""Name of the stratification."""
sources: List[str]
sources: list[str]
"""A list of the columns and values needed as input for the `mapper`."""
categories: List[str]
categories: list[str]
"""Exhaustive list of all possible stratification values."""
excluded_categories: List[str]
excluded_categories: list[str]
"""List of possible stratification values to exclude from results processing.
If None (the default), will use exclusions as defined in the configuration."""
mapper: Optional[
Union[
Callable[[Union[pd.Series, pd.DataFrame]], pd.Series],
Callable[[ScalarValue], str],
]
]
mapper: VectorMapper | ScalarMapper | None
"""A callable that maps the columns and value pipelines specified by the
`requires_columns` and `requires_values` arguments to the stratification
categories. It can either map the entire population or an individual
Expand All @@ -56,13 +50,12 @@ class Stratification:
def __str__(self) -> str:
return (
f"Stratification '{self.name}' with sources {self.sources}, "
f"categories {self.categories}, and mapper {self.mapper.__name__}"
f"categories {self.categories}, and mapper {getattr(self.mapper, '__name__', repr(self.mapper))}"
)

def __post_init__(self) -> None:
"""Assign a default `mapper` if none was provided and check for non-empty
`categories` and `sources` otherwise.
Raises
------
ValueError
Expand All @@ -72,21 +65,13 @@ def __post_init__(self) -> None:
ValueError
If the sources argument is empty.
"""
if self.mapper is None:
if len(self.sources) != 1:
raise ValueError(
f"No mapper but {len(self.sources)} stratification sources are "
f"provided for stratification {self.name}. The list of sources "
"must be of length 1 if no mapper is provided."
)
self.mapper = self._default_mapper
self.is_vectorized = True
self.vectorized_mapper = self._get_vectorized_mapper(self.mapper, self.is_vectorized)
if not self.categories:
raise ValueError("The categories argument must be non-empty.")
if not self.sources:
raise ValueError("The sources argument must be non-empty.")

def stratify(self, population: pd.DataFrame) -> pd.Series:
def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]:
"""Apply the `mapper` to the population `sources` columns to create a new
Series to be added to the population.
Expand All @@ -108,29 +93,46 @@ def stratify(self, population: pd.DataFrame) -> pd.Series:
ValueError
If the mapper returns any values not in `categories` or `excluded_categories`.
"""
if self.is_vectorized:
mapped_column = self.mapper(population[self.sources])
else:
mapped_column = population[self.sources].apply(self.mapper, axis=1)
mapped_column = self.vectorized_mapper(population[self.sources])
unknown_categories = set(mapped_column) - set(
self.categories + self.excluded_categories
)
# Reduce all nans to a single one
unknown_categories = [cat for cat in unknown_categories if not pd.isna(cat)]
unknown_categories = {cat for cat in unknown_categories if not pd.isna(cat)}
if mapped_column.isna().any():
unknown_categories.append(mapped_column[mapped_column.isna()].iat[0])
unknown_categories.add(mapped_column[mapped_column.isna()].iat[0])
if unknown_categories:
raise ValueError(f"Invalid values mapped to {self.name}: {unknown_categories}")

# Convert the dtype to the allowed categories. Note that this will
# result in Nans for any values in excluded_categories.
mapped_column = mapped_column.astype(
return mapped_column.astype(
CategoricalDtype(categories=self.categories, ordered=True)
)
return mapped_column

def _get_vectorized_mapper(
self,
user_provided_mapper: VectorMapper | ScalarMapper | None,
is_vectorized: bool,
) -> VectorMapper:
"""Return a mapper that operates on a whole DataFrame.

Parameters
----------
user_provided_mapper
The mapper supplied by the caller, or None to fall back to
`_default_mapper`.
is_vectorized
Whether `user_provided_mapper` already accepts a whole DataFrame.
Only consulted when a mapper is provided.

Returns
-------
`_default_mapper` if no mapper was provided; the provided mapper
itself if `is_vectorized` is true; otherwise a wrapper that calls
`population.apply(user_provided_mapper, axis=1)` on its input.

Raises
------
ValueError
If no mapper is provided and `sources` does not contain exactly
one element.
"""
if user_provided_mapper is None:
# The default mapper is only used with exactly one source.
if len(self.sources) != 1:
raise ValueError(
f"No mapper but {len(self.sources)} stratification sources are "
f"provided for stratification {self.name}. The list of sources "
"must be of length 1 if no mapper is provided."
)
return self._default_mapper
elif is_vectorized:
# The caller declared the mapper vectorized, so it is returned unchanged;
# the type checker cannot narrow the union from a boolean flag.
return user_provided_mapper # type: ignore [return-value]
else:
# Lift the non-vectorized mapper by applying it along axis=1.
return lambda population: population.apply(user_provided_mapper, axis=1)

@staticmethod
def _default_mapper(pop: pd.DataFrame) -> pd.Series:
def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]:
"""Default stratification mapper that squeezes a DataFrame to a Series.
Parameters
Expand All @@ -144,9 +146,10 @@ def _default_mapper(pop: pd.DataFrame) -> pd.Series:
Notes
-----
The input DataFrame is guaranteeed to have a single column.
The input DataFrame is guaranteed to have a single column.
"""
return pop.squeeze(axis=1)
squeezed_pop: pd.Series[Any] = pop.squeeze(axis=1)
return squeezed_pop


def get_mapped_col_name(col_name: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions tests/framework/results/test_stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_stratification_init_raises(sources, categories, mapper, msg_match):
sorting_hat_bad_mapping,
False,
ValueError,
"Invalid values mapped to hogwarts_house: ['pancakes']",
"Invalid values mapped to hogwarts_house: {'pancakes'}",
),
(
["middle_initial"],
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_stratification_init_raises(sources, categories, mapper, msg_match):
lambda df: pd.Series(np.nan, index=df.index),
True,
ValueError,
f"Invalid values mapped to hogwarts_house: [{np.nan}]",
f"Invalid values mapped to hogwarts_house: {{{np.nan}}}",
),
],
ids=[
Expand Down

0 comments on commit c12782f

Please sign in to comment.