Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sbachmei/mic 5549/mypy results context #538

Merged
merged 6 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.1.1 - TBD/TBD/TBD**

- Fix mypy errors in vivarium/framework/results/context.py

**3.1.0 - 11/07/24**

- Drop support for python 3.9
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ exclude = [
'src/vivarium/framework/lookup/manager.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/results/context.py',
'src/vivarium/framework/results/interface.py',
'src/vivarium/framework/results/manager.py',
'src/vivarium/framework/results/observer.py',
Expand Down
50 changes: 25 additions & 25 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# mypy: ignore-errors
"""
===============
Results Context
===============

"""

from __future__ import annotations

from collections import defaultdict
from typing import Callable, Generator, List, Optional, Tuple, Type, Union
from collections.abc import Callable, Generator
from typing import Any, Type

import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy
Expand All @@ -21,7 +23,7 @@
get_mapped_col_name,
get_original_col_name,
)
from vivarium.types import ScalarValue
from vivarium.types import ScalarMapper, VectorMapper


class ResultsContext:
Expand Down Expand Up @@ -52,10 +54,12 @@ class ResultsContext:
"""

def __init__(self) -> None:
self.default_stratifications: List[str] = []
self.stratifications: List[Stratification] = []
self.default_stratifications: list[str] = []
self.stratifications: list[Stratification] = []
self.excluded_categories: dict[str, list[str]] = {}
self.observations: defaultdict = defaultdict(lambda: defaultdict(list))
self.observations: defaultdict[
str, defaultdict[tuple[str, tuple[str, ...] | None], list[BaseObservation]]
] = defaultdict(lambda: defaultdict(list))

@property
def name(self) -> str:
Expand All @@ -73,7 +77,7 @@ def setup(self, builder: Builder) -> None:
)

# noinspection PyAttributeOutsideInit
def set_default_stratifications(self, default_grouping_columns: List[str]) -> None:
def set_default_stratifications(self, default_grouping_columns: list[str]) -> None:
"""Set the default stratifications to be used by stratified observations.

Parameters
Expand All @@ -96,15 +100,10 @@ def set_default_stratifications(self, default_grouping_columns: List[str]) -> No
def add_stratification(
self,
name: str,
sources: List[str],
categories: List[str],
excluded_categories: Optional[List[str]],
mapper: Optional[
Union[
Callable[[Union[pd.Series, pd.DataFrame]], pd.Series],
Callable[[ScalarValue], str],
]
],
sources: list[str],
categories: list[str],
excluded_categories: list[str] | None,
mapper: VectorMapper | ScalarMapper | None,
is_vectorized: bool,
) -> None:
"""Add a stratification to the results context.
Expand Down Expand Up @@ -191,7 +190,7 @@ def register_observation(
name: str,
pop_filter: str,
when: str,
**kwargs,
**kwargs: Any,
) -> None:
"""Add an observation to the results context.

Expand Down Expand Up @@ -242,10 +241,10 @@ def register_observation(
def gather_results(
self, population: pd.DataFrame, lifecycle_phase: str, event: Event
) -> Generator[
Tuple[
Optional[pd.DataFrame],
Optional[str],
Optional[Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]],
tuple[
pd.DataFrame | None,
str | None,
Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] | None,
],
None,
None,
Expand Down Expand Up @@ -302,6 +301,7 @@ def gather_results(
if filtered_pop.empty:
yield None, None, None
else:
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]
if stratification_names is None:
pop = filtered_pop
else:
Expand All @@ -317,7 +317,7 @@ def _filter_population(
self,
population: pd.DataFrame,
pop_filter: str,
stratification_names: Optional[tuple[str, ...]],
stratification_names: tuple[str, ...] | None,
) -> pd.DataFrame:
"""Filter out simulants not to observe."""
pop = population.query(pop_filter) if pop_filter else population.copy()
Expand All @@ -334,8 +334,8 @@ def _filter_population(

@staticmethod
def _get_groups(
stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy:
stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy[tuple[str, ...] | str]:
"""Group the population by stratification.

Notes
Expand All @@ -356,7 +356,7 @@ def _get_groups(
)
else:
pop_groups = filtered_pop.groupby(lambda _: "all")
return pop_groups
return pop_groups # type: ignore[return-value]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to talk about this. I think it's a discrepancy between mypy requiring DataFrameGroupBy to be parameterized as a generic and Python raising a runtime error ("DataFrameGroupBy is not subscriptable") when it is subscripted. Adding `from __future__ import annotations` does not fix this the way it does for pd.Series.

def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
"""Convert the temporary stratified mapped index names back to their original names."""
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union

import pandas as pd

from vivarium.framework.event import Event
from vivarium.framework.results.context import ResultsContext
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.values import Pipeline
from vivarium.manager import Manager
from vivarium.types import ScalarValue
Expand Down Expand Up @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series:

def register_observation(
self,
observation_type,
observation_type: Type[BaseObservation],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't causing an error, but it was just wrong. At this point we're sending in the class itself, NOT an instantiation of it. (Actually, it might be worth just saying observation_type: Type[StratifiedObservation | UnstratifiedObservation | AddingObservation | ConcatenatingObservation] if folks prefer, because that's what it really is. But they all inherit from BaseObservation, so this is less typing.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should BaseObservation be called Observation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had it called that at one point and changed it to this, but now I can't recall why. It does subclass from abc, though there are no abstract methods or anything on it. Probably doesn't matter?

Copy link
Contributor

@patricktnast patricktnast Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type[C] is covariant, so I think putting in the base observation makes sense. It looks like there is a builtin type you could use instead, though

is_stratified: bool,
name: str,
pop_filter: str,
Expand Down
13 changes: 9 additions & 4 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from abc import ABC
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING

import pandas as pd
from pandas.api.types import CategoricalDtype
Expand All @@ -35,6 +35,9 @@

VALUE_COLUMN = "value"

if TYPE_CHECKING:
_PandasGroup = pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]


@dataclass
class BaseObservation(ABC):
Expand All @@ -60,7 +63,8 @@ class BaseObservation(ABC):
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming the things in this file were just type-hinted incorrectly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, though I'm not gonna lie - I'm having trouble understanding how to type-hint DataFrameGroupBys...

[_PandasGroup, tuple[str, ...] | None],
pd.DataFrame,
]
"""Method or function that gathers the new observation results."""
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
Expand All @@ -76,7 +80,7 @@ class BaseObservation(ABC):
def observe(
self,
event: Event,
df: pd.DataFrame | DataFrameGroupBy[str],
df: _PandasGroup,
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Expand Down Expand Up @@ -139,7 +143,8 @@ def __init__(
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None
df: _PandasGroup,
_: tuple[str, ...] | None,
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
raise TypeError(
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/framework/results/stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
===============

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.types import ScalarMapper, VectorMapper

STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"

# TODO: Parameterizing pandas objects fails below python 3.12
VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/vivarium/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from numbers import Number
from typing import Union
Expand All @@ -24,3 +25,6 @@
Timedelta = Union[pd.Timedelta, timedelta]
ClockTime = Union[Time, int]
ClockStepSize = Union[Timedelta, int]

VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]
Loading