From 630b06d5f7f158b48a8a2604c7edd65b9888ded8 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Tue, 12 Nov 2024 13:17:31 -0800 Subject: [PATCH 1/5] modernize typing --- CHANGELOG.rst | 4 +++ pyproject.toml | 1 - src/vivarium/framework/results/context.py | 39 +++++++++++------------ 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 78d869b6f..4b088aae4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.1.1 - TBD/TBD/TBD** + + - Fix mypy errors in vivarium/framework/results/context.py + **3.1.0 - 11/07/24** - Drop support for python 3.9 diff --git a/pyproject.toml b/pyproject.toml index 768b98786..c10dc069b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ exclude = [ 'src/vivarium/framework/lookup/manager.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/results/context.py', 'src/vivarium/framework/results/interface.py', 'src/vivarium/framework/results/manager.py', 'src/vivarium/framework/results/observer.py', diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 353cecbae..cabe7cc08 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ =============== Results Context @@ -7,7 +6,8 @@ """ from collections import defaultdict -from typing import Callable, Generator, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Generator +from typing import Type import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -52,8 +52,8 @@ class ResultsContext: """ def __init__(self) -> None: - self.default_stratifications: List[str] = [] - self.stratifications: List[Stratification] = [] + self.default_stratifications: list[str] = [] + self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} self.observations: defaultdict = defaultdict(lambda: defaultdict(list)) @@ -73,7 +73,7 @@ def setup(self, builder: Builder) -> None: ) # noinspection PyAttributeOutsideInit - def set_default_stratifications(self, default_grouping_columns: List[str]) -> None: + def set_default_stratifications(self, default_grouping_columns: list[str]) -> None: """Set the default stratifications to be used by stratified observations. Parameters @@ -96,15 +96,14 @@ def set_default_stratifications(self, default_grouping_columns: List[str]) -> No def add_stratification( self, name: str, - sources: List[str], - categories: List[str], - excluded_categories: Optional[List[str]], - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ], + sources: list[str], + categories: list[str], + excluded_categories: list[str] | None, + mapper: ( + Callable[[pd.Series | pd.DataFrame], pd.Series] + | Callable[[ScalarValue], str] + | None + ), is_vectorized: bool, ) -> None: """Add a stratification to the results context. @@ -242,10 +241,10 @@ def register_observation( def gather_results( self, population: pd.DataFrame, lifecycle_phase: str, event: Event ) -> Generator[ - Tuple[ - Optional[pd.DataFrame], - Optional[str], - Optional[Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]], + tuple[ + pd.DataFrame | None, + str | None, + Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] | None, ], None, None, @@ -317,7 +316,7 @@ def _filter_population( self, population: pd.DataFrame, pop_filter: str, - stratification_names: Optional[tuple[str, ...]], + stratification_names: tuple[str, ...] | None, ) -> pd.DataFrame: """Filter out simulants not to observe.""" pop = population.query(pop_filter) if pop_filter else population.copy() @@ -334,7 +333,7 @@ def _filter_population( @staticmethod def _get_groups( - stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame + stratifications: tuple[str, ...], filtered_pop: pd.DataFrame ) -> DataFrameGroupBy: """Group the population by stratification. From f015a565c4b66228acadb1af34538e61ee0afa7d Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Tue, 12 Nov 2024 16:28:25 -0800 Subject: [PATCH 2/5] Fix mypy errors in framework/results/context.py --- src/vivarium/framework/results/context.py | 23 ++++++++++--------- src/vivarium/framework/results/manager.py | 5 ++-- src/vivarium/framework/results/observation.py | 13 +++++++---- .../framework/results/stratification.py | 7 +++--- src/vivarium/types.py | 4 ++++ 5 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index cabe7cc08..cee07d481 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -5,9 +5,11 @@ """ +from __future__ import annotations + from collections import defaultdict from collections.abc import Callable, Generator -from typing import Type +from typing import Any, Type import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -21,7 +23,7 @@ get_mapped_col_name, get_original_col_name, ) -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper class ResultsContext: @@ -55,7 +57,9 @@ def __init__(self) -> None: self.default_stratifications: list[str] = [] self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} - self.observations: defaultdict = defaultdict(lambda: defaultdict(list)) + self.observations: defaultdict[ + str, defaultdict[tuple[str, tuple[str, ...] | None], list[BaseObservation]] + ] = defaultdict(lambda: defaultdict(list)) @property def name(self) -> str: @@ -99,11 +103,7 @@ def add_stratification( sources: list[str], categories: list[str], excluded_categories: list[str] | None, - mapper: ( - Callable[[pd.Series | pd.DataFrame], pd.Series] - | Callable[[ScalarValue], str] - | None - ), + mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> None: """Add a stratification to the results context. @@ -190,7 +190,7 @@ def register_observation( name: str, pop_filter: str, when: str, - **kwargs, + **kwargs: Any, ) -> None: """Add an observation to the results context. @@ -301,6 +301,7 @@ def gather_results( if filtered_pop.empty: yield None, None, None else: + pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] if stratification_names is None: pop = filtered_pop else: @@ -334,7 +335,7 @@ def _filter_population( @staticmethod def _get_groups( stratifications: tuple[str, ...], filtered_pop: pd.DataFrame - ) -> DataFrameGroupBy: + ) -> DataFrameGroupBy[tuple[str, ...] | str]: """Group the population by stratification. Notes @@ -355,7 +356,7 @@ def _get_groups( ) else: pop_groups = filtered_pop.groupby(lambda _: "all") - return pop_groups + return pop_groups # type: ignore[return-value] def _rename_stratification_columns(self, results: pd.DataFrame) -> None: """Convert the temporary stratified mapped index names back to their original names.""" diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 6c18d279d..b9db068c9 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -8,12 +8,13 @@ from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union import pandas as pd from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import BaseObservation from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarValue @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type, + observation_type: Type[BaseObservation], is_stratified: bool, name: str, pop_filter: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f400900d4..79a99e483 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -24,7 +24,7 @@ from abc import ABC from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING import pandas as pd from pandas.api.types import CategoricalDtype @@ -35,6 +35,9 @@ VALUE_COLUMN = "value" +if TYPE_CHECKING: + _PandasGroup = pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] + @dataclass class BaseObservation(ABC): @@ -60,7 +63,8 @@ class BaseObservation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame + [_PandasGroup, tuple[str, ...] | None], + pd.DataFrame, ] """Method or function that gathers the new observation results.""" results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] @@ -76,7 +80,7 @@ class BaseObservation(ABC): def observe( self, event: Event, - df: pd.DataFrame | DataFrameGroupBy[str], + df: _PandasGroup, stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -139,7 +143,8 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None + df: _PandasGroup, + _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): raise TypeError( diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 7be52813b..e0d4d1f7a 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -4,19 +4,20 @@ =============== """ + from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype +from vivarium.types import ScalarMapper, VectorMapper + STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" # TODO: Parameterizing pandas objects fails below python 3.12 -VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] -ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] @dataclass diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 5d813e312..29f8af0ca 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime, timedelta from numbers import Number from typing import Union @@ -24,3 +25,6 @@ Timedelta = Union[pd.Timedelta, timedelta] ClockTime = Union[Time, int] ClockStepSize = Union[Timedelta, int] + +VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] +ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] From e440eb84a72b7546bc3ad3f78763dcf33a8f3a73 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Tue, 12 Nov 2024 16:43:30 -0800 Subject: [PATCH 3/5] fix --- src/vivarium/framework/results/observation.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 79a99e483..8bf081b3b 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -24,7 +24,6 @@ from abc import ABC from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING import pandas as pd from pandas.api.types import CategoricalDtype @@ -35,9 +34,6 @@ VALUE_COLUMN = "value" -if TYPE_CHECKING: - _PandasGroup = pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] - @dataclass class BaseObservation(ABC): @@ -63,7 +59,7 @@ class BaseObservation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [_PandasGroup, tuple[str, ...] | None], + [pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None], pd.DataFrame, ] """Method or function that gathers the new observation results.""" @@ -80,7 +76,7 @@ class BaseObservation(ABC): def observe( self, event: Event, - df: _PandasGroup, + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -143,7 +139,7 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: _PandasGroup, + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): From 8ce099f599842b704c8f5c80967a281c6c2fa959 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Wed, 13 Nov 2024 09:24:46 -0800 Subject: [PATCH 4/5] rename BaseObservation to Observation --- docs/source/concepts/results.rst | 26 +++++++++---------- src/vivarium/framework/results/context.py | 6 ++--- src/vivarium/framework/results/manager.py | 4 +-- src/vivarium/framework/results/observation.py | 8 +++--- src/vivarium/framework/results/observer.py | 2 +- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst index 3581aed59..4463a7fb8 100644 --- a/docs/source/concepts/results.rst +++ b/docs/source/concepts/results.rst @@ -303,7 +303,7 @@ A couple other more specific and commonly used observations are provided as well that gathers new results and concatenates them to any existing results. Ideally, all concrete classes should inherit from the -:class:`BaseObservation ` +:class:`Observation ` abstract base class, which contains the common attributes between observation types: .. list-table:: **Common Observation Attributes** @@ -312,40 +312,40 @@ abstract base class, which contains the common attributes between observation ty * - Attribute - Description - * - | :attr:`name ` + * - | :attr:`name ` - | Name of the observation. It will also be the name of the output results file | for this particular observation. - * - | :attr:`pop_filter ` + * - | :attr:`pop_filter ` - | A Pandas query filter string to filter the population down to the simulants | who should be considered for the observation. - * - | :attr:`when ` + * - | :attr:`when ` - | Name of the lifecycle phase the observation should happen. Valid values are: | "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - * - | :attr:`results_initializer ` + * - | :attr:`results_initializer ` - | Method or function that initializes the raw observation results | prior to starting the simulation. This could return, for example, an empty | DataFrame or one with a complete set of stratifications as the index and | all values set to 0.0. - * - | :attr:`results_gatherer ` + * - | :attr:`results_gatherer ` - | Method or function that gathers the new observation results. - * - | :attr:`results_updater ` + * - | :attr:`results_updater ` - | Method or function that updates existing raw observation results with newly | gathered results. - * - | :attr:`results_formatter ` + * - | :attr:`results_formatter ` - | Method or function that formats the raw observation results. - * - | :attr:`stratifications ` + * - | :attr:`stratifications ` - | Optional tuple of column names for the observation to stratify by. - * - | :attr:`to_observe ` + * - | :attr:`to_observe ` - | Method or function that determines whether to perform an observation on this Event. -The **BaseObservation** also contains the -:meth:`observe ` +The **Observation** also contains the +:meth:`observe ` method which is called at each :ref:`event ` and :ref:`time step ` to determine whether or not the observation should be recorded, and if so, gathers the results and stores them in the results system. .. note:: - All four observation types discussed above inherit from the **BaseObservation** + All four observation types discussed above inherit from the **Observation** abstract base class. What differentiates them are the assigned attributes (e.g. defining the **results_updater** to be an adding method for the **AddingObservation**) or adding other attributes as necessary (e.g. diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index cee07d481..119bebe0a 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -17,7 +17,7 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.results.exceptions import ResultsConfigurationError -from vivarium.framework.results.observation import BaseObservation +from vivarium.framework.results.observation import Observation from vivarium.framework.results.stratification import ( Stratification, get_mapped_col_name, @@ -58,7 +58,7 @@ def __init__(self) -> None: self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} self.observations: defaultdict[ - str, defaultdict[tuple[str, tuple[str, ...] | None], list[BaseObservation]] + str, defaultdict[tuple[str, tuple[str, ...] | None], list[Observation]] ] = defaultdict(lambda: defaultdict(list)) @property @@ -186,7 +186,7 @@ def add_stratification( def register_observation( self, - observation_type: Type[BaseObservation], + observation_type: Type[Observation], name: str, pop_filter: str, when: str, diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index b9db068c9..ee6864760 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -14,7 +14,7 @@ from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext -from vivarium.framework.results.observation import BaseObservation +from vivarium.framework.results.observation import Observation from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarValue @@ -302,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type: Type[BaseObservation], + observation_type: Type[Observation], is_stratified: bool, name: str, pop_filter: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index 8bf081b3b..8037573f4 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -6,7 +6,7 @@ An observation is a class object that records simulation results; they are responsible for initializing, gathering, updating, and formatting results. -The provided :class:`BaseObservation` class is an abstract base class that should +The provided :class:`Observation` class is an abstract base class that should be subclassed by concrete observations. While there are no required abstract methods to define when subclassing, the class does provide common attributes as well as an `observe` method that determines whether to observe results for a given event. @@ -36,7 +36,7 @@ @dataclass -class BaseObservation(ABC): +class Observation(ABC): """An abstract base dataclass to be inherited by concrete observations. This class includes an :meth:`observe ` method that determines whether @@ -100,7 +100,7 @@ def observe( return self.results_gatherer(df, stratifications) -class UnstratifiedObservation(BaseObservation): +class UnstratifiedObservation(Observation): """Concrete class for observing results that are not stratified. The parent class `stratifications` are set to None and the `results_initializer` @@ -182,7 +182,7 @@ def create_empty_df( return pd.DataFrame() -class StratifiedObservation(BaseObservation): +class StratifiedObservation(Observation): """Concrete class for observing stratified results. The parent class `results_initializer` and `results_gatherer` methods are diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 7a37a5db6..a04c71b2b 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -5,7 +5,7 @@ ========= An observer is a component that is responsible for registering -:class:`observations ` +:class:`observations ` to the simulation. The provided :class:`Observer` class is an abstract base class that should be subclassed From 3b80c2438bc81266f99d5377b07f4ddf935f176c Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Wed, 13 Nov 2024 09:33:35 -0800 Subject: [PATCH 5/5] use 'type' built-in --- src/vivarium/framework/results/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 119bebe0a..987c5777d 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable, Generator -from typing import Any, Type +from typing import Any import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -186,7 +186,7 @@ def add_stratification( def register_observation( self, - observation_type: Type[Observation], + observation_type: type[Observation], name: str, pop_filter: str, when: str,