From 25cd81533b682849f578381fa6c2faf86c9c9c2b Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:28:04 -0600 Subject: [PATCH] mypy fix index_map.py (#482) --- CHANGELOG.rst | 4 + docs/nitpick-exceptions | 2 + pyproject.toml | 1 - .../framework/randomness/index_map.py | 79 ++++++++++--------- tests/framework/randomness/test_index_map.py | 17 ---- 5 files changed, 49 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bb8ae0164..3a2362430 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.0.6 - 09/20/24** + + - Fix mypy errors: vivarium/framework/randomness/index_map.py + **3.0.5 - 09/17/24** - Pin Sphinx below 8.0 diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 11c95272a..1b3100e9d 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -7,6 +7,8 @@ py:class pandas.core.frame.DataFrame py:class pandas.core.series.Series py:class pandas.core.generic.PandasObject py:class pandas.core.groupby.generic.DataFrameGroupBy +py:class pd.Series +py:class pd.Index #scipy py:class scipy.stats._distn_infrastructure.rv_continuous diff --git a/pyproject.toml b/pyproject.toml index 8265e9a9e..b3c3a1510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ exclude = [ 'src/vivarium/framework/plugins.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/randomness/index_map.py', 'src/vivarium/framework/randomness/manager.py', 'src/vivarium/framework/randomness/stream.py', 'src/vivarium/framework/resource.py', diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index 2f95431d1..9c1ee6f99 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ =================== Randomness IndexMap @@ -11,11 +10,14 @@ """ -import datetime -from typing import List, Tuple, Union +from __future__ import annotations + +from datetime import datetime +from typing import Any import numpy as np import pandas as pd +import pandas.api.types as pdt from vivarium.framework.randomness.exceptions import RandomnessError @@ -26,10 +28,11 @@ class IndexMap: SIM_INDEX_COLUMN = "simulant_index" TEN_DIGIT_MODULUS = 10_000_000_000 - def __init__(self, key_columns: List[str] = None, size: int = 1_000_000): + def __init__(self, key_columns: list[str] | None = None, size: int = 1_000_000): self._use_crn = bool(key_columns) - self._key_columns = key_columns - self._map = None + self._key_columns = key_columns if key_columns else [] + self._map: pd.Series[int] | None = None + """The mapping between the key columns and the randomness index.""" self._size = size def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: @@ -62,7 +65,7 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN) self._map = final_mapping - def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]: + def _parse_new_keys(self, new_keys: pd.DataFrame) -> tuple[pd.Index[Any], pd.Index[Any]]: """Parses raw new keys into the mapping index. This returns a tuple of the new and final mapping indices. Both are pandas @@ -88,12 +91,12 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul if self._map is None: final_mapping_index = new_mapping_index else: - final_mapping_index = self._map.index.append(new_mapping_index) + final_mapping_index = self._map.index.append(new_mapping_index) # type: ignore [no-untyped-call] return new_mapping_index, final_mapping_index def _build_final_mapping( - self, new_mapping_index: pd.Index, clock_time: pd.Timestamp - ) -> pd.Series: + self, new_mapping_index: pd.Index[Any], clock_time: pd.Timestamp + ) -> pd.Series[int]: """Builds a new mapping between key columns and the randomness index from the new mapping index and the existing map. @@ -123,9 +126,9 @@ def _build_final_mapping( def _resolve_collisions( self, - new_key_index: pd.MultiIndex, - current_mapping: pd.Series, - ) -> pd.Series: + new_key_index: pd.Index[Any], + current_mapping: pd.Series[int], + ) -> pd.Series[int]: """Resolves collisions in the new mapping by perturbing the hash. Parameters @@ -151,7 +154,7 @@ def _resolve_collisions( salt += 1 return current_mapping - def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: + def _hash(self, keys: pd.Index[Any], salt: int | pd.Timestamp = 0) -> pd.Series[int]: """Hashes the index into an integer index in the range [0, self.stride] Parameters @@ -159,7 +162,7 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: keys The new index to hash. salt - An integer used to perturb the hash in a deterministic way. Useful + Value used to perturb the hash in a deterministic way. Useful in dealing with collisions. Returns @@ -170,9 +173,9 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: """ key_frame = keys.to_frame() new_map = pd.Series(0, index=keys) - salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys)) + salt_series = self._convert_to_ten_digit_int(pd.Series(salt, index=keys)) - for i, column_name in enumerate(key_frame.columns): + for _i, column_name in enumerate(key_frame.columns): column = self._convert_to_ten_digit_int(key_frame[column_name]) primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27] @@ -183,11 +186,13 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: # our map size the amount of additional periodicity this # introduces is pretty trivial. out *= np.power(p, self._digit(column, idx)) - new_map += out + salt + new_map += out + salt_series return new_map % len(self) - def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: + def _convert_to_ten_digit_int( + self, column: pd.Series[datetime | int | float] + ) -> pd.Series[int]: """Converts a column of datetimes, integers, or floats into a column of 10 digit integers. @@ -206,47 +211,49 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: If the column contains data that is neither a datetime-like nor numeric. """ - if isinstance(column.iloc[0], datetime.datetime): - column = self._clip_to_seconds(column.astype(np.int64)) - elif np.issubdtype(column.iloc[0], np.integer): + if pdt.is_datetime64_any_dtype(column): + integers = self._clip_to_seconds(column.astype(np.int64)) + elif pdt.is_integer_dtype(column): if not len(column >= 0) == len(column): raise RandomnessError( "Values in integer columns must be greater than or equal to zero." ) - column = self._spread(column) - elif np.issubdtype(column.iloc[0], np.floating): - column = self._shift(column) + integers = self._spread(column.astype(int)) + elif pdt.is_float_dtype(column): + integers = self._shift(column.astype(float)) else: raise RandomnessError( f"Unhashable column type {type(column.iloc[0])}. " "IndexMap accepts datetime like columns and numeric columns." ) - return column + return integers @staticmethod - def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]: + def _digit(m: pd.Series[int], n: int) -> pd.Series[int]: """Returns the nth digit of each number in m.""" - return (m // (10**n)) % 10 + nth_digits: pd.Series[int] = (m // (10**n)) % 10 + return nth_digits @staticmethod - def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]: + def _clip_to_seconds(m: pd.Series[int]) -> pd.Series[int]: """Clips UTC datetime in nanoseconds to seconds.""" return m // pd.Timedelta(1, unit="s").value - def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]: + def _spread(self, m: pd.Series[int]) -> pd.Series[int]: """Spreads out integer values to give smaller values more weight.""" return (m * 111_111) % self.TEN_DIGIT_MODULUS - def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]: + def _shift(self, m: pd.Series[float]) -> pd.Series[int]: """Shifts floats so that the first 10 decimal digits are significant.""" out = m % 1 * self.TEN_DIGIT_MODULUS // 1 - if isinstance(out, pd.Series): - return out.astype("int64") - return int(out) + return out.astype("int64") - def __getitem__(self, index: pd.Index) -> np.ndarray: + def __getitem__(self, index: pd.Index[int]) -> np.ndarray[int, Any]: if self._use_crn: - return self._map.loc[index].values + if self._map is None: + raise RandomnessError("IndexMap is empty") + else: + return self._map.loc[index].to_numpy() else: return index.values diff --git a/tests/framework/randomness/test_index_map.py b/tests/framework/randomness/test_index_map.py index 8e7e9e5e9..67c1e33b4 100644 --- a/tests/framework/randomness/test_index_map.py +++ b/tests/framework/randomness/test_index_map.py @@ -61,13 +61,6 @@ def map_size_and_hashed_values(request): return len(m), m._hash(keys) -def test_digit_scalar(): - m = IndexMap() - k = 123456789 - for i in range(10): - assert m._digit(k, i) == 10 - (i + 1) - - def test_digit_series(): m = IndexMap() k = pd.Series(123456789, index=range(10000)) @@ -94,11 +87,6 @@ def test_clip_to_seconds_series(): assert m._clip_to_seconds(k).unique()[0] == stamp -def test_spread_scalar(): - m = IndexMap() - assert m._spread(1234567890) == 4072825790 - - def test_spread_series(): m = IndexMap() s = pd.Series(1234567890, index=range(10000)) @@ -106,11 +94,6 @@ def test_spread_series(): assert m._spread(s).unique()[0] == 4072825790 -def test_shift_scalar(): - m = IndexMap() - assert m._shift(1.1234567890) == 1234567890 - - def test_shift_series(): m = IndexMap() s = pd.Series(1.1234567890, index=range(10000))