From 902f115d017d3572c567f134eca817aabb367c89 Mon Sep 17 00:00:00 2001 From: sbachmei Date: Thu, 19 Sep 2024 14:11:38 -0600 Subject: [PATCH 1/3] mypy fix index_map.py --- CHANGELOG.rst | 4 +- docs/nitpick-exceptions | 2 + pyproject.toml | 1 - .../framework/randomness/index_map.py | 78 ++++++++++--------- tests/framework/randomness/test_index_map.py | 17 ---- 5 files changed, 47 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bb8ae0164..d21afe1fb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,6 @@ -**3.0.5 - 09/17/24** +**3.0.5 - 09/20/24** - - Pin Sphinx below 8.0 + - Fix mypy errors: vivarium/framework/randomness/index_map.py **3.0.4 - 09/12/24** diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 11c95272a..1b3100e9d 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -7,6 +7,8 @@ py:class pandas.core.frame.DataFrame py:class pandas.core.series.Series py:class pandas.core.generic.PandasObject py:class pandas.core.groupby.generic.DataFrameGroupBy +py:class pd.Series +py:class pd.Index #scipy py:class scipy.stats._distn_infrastructure.rv_continuous diff --git a/pyproject.toml b/pyproject.toml index 8265e9a9e..b3c3a1510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ exclude = [ 'src/vivarium/framework/plugins.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/randomness/index_map.py', 'src/vivarium/framework/randomness/manager.py', 'src/vivarium/framework/randomness/stream.py', 'src/vivarium/framework/resource.py', diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index 2f95431d1..267717016 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ =================== Randomness IndexMap @@ -11,8 +10,10 @@ """ +from __future__ import annotations + import datetime -from typing import List, Tuple, Union +from typing import Any, Optional, Union import numpy as np import pandas as pd @@ -26,10 +27,11 @@ class IndexMap: SIM_INDEX_COLUMN = "simulant_index" TEN_DIGIT_MODULUS = 10_000_000_000 - def __init__(self, key_columns: List[str] = None, size: int = 1_000_000): + def __init__(self, key_columns: Optional[list[str]] = None, size: int = 1_000_000): self._use_crn = bool(key_columns) - self._key_columns = key_columns - self._map = None + self._key_columns = [] if key_columns is None else key_columns + self._map: Optional[pd.Series[int]] = None + """The mapping between the key columns and the randomness index.""" self._size = size def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: @@ -62,7 +64,7 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN) self._map = final_mapping - def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]: + def _parse_new_keys(self, new_keys: pd.DataFrame) -> tuple[pd.Index[Any], pd.Index[Any]]: """Parses raw new keys into the mapping index. This returns a tuple of the new and final mapping indices. Both are pandas @@ -88,12 +90,12 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul if self._map is None: final_mapping_index = new_mapping_index else: - final_mapping_index = self._map.index.append(new_mapping_index) + final_mapping_index = self._map.index.append(new_mapping_index) # type: ignore [no-untyped-call] return new_mapping_index, final_mapping_index def _build_final_mapping( - self, new_mapping_index: pd.Index, clock_time: pd.Timestamp - ) -> pd.Series: + self, new_mapping_index: pd.Index[Any], clock_time: pd.Timestamp + ) -> pd.Series[int]: """Builds a new mapping between key columns and the randomness index from the new mapping index and the existing map. @@ -123,9 +125,9 @@ def _build_final_mapping( def _resolve_collisions( self, - new_key_index: pd.MultiIndex, - current_mapping: pd.Series, - ) -> pd.Series: + new_key_index: pd.Index[Any], + current_mapping: pd.Series[int], + ) -> pd.Series[int]: """Resolves collisions in the new mapping by perturbing the hash. Parameters @@ -151,7 +153,9 @@ def _resolve_collisions( salt += 1 return current_mapping - def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: + def _hash( + self, keys: pd.Index[Any], salt: Union[int, pd.Timestamp] = 0 + ) -> pd.Series[int]: """Hashes the index into an integer index in the range [0, self.stride] Parameters @@ -159,7 +163,7 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: keys The new index to hash. salt - An integer used to perturb the hash in a deterministic way. Useful + Value used to perturb the hash in a deterministic way. Useful in dealing with collisions. Returns @@ -170,9 +174,9 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: """ key_frame = keys.to_frame() new_map = pd.Series(0, index=keys) - salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys)) + salt_series = self._convert_to_ten_digit_int(pd.Series(salt, index=keys)) - for i, column_name in enumerate(key_frame.columns): + for _i, column_name in enumerate(key_frame.columns): column = self._convert_to_ten_digit_int(key_frame[column_name]) primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27] @@ -183,11 +187,13 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: # our map size the amount of additional periodicity this # introduces is pretty trivial. out *= np.power(p, self._digit(column, idx)) - new_map += out + salt + new_map += out + salt_series return new_map % len(self) - def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: + def _convert_to_ten_digit_int( + self, column: pd.Series[Union[datetime.datetime, int, float]] + ) -> pd.Series[int]: """Converts a column of datetimes, integers, or floats into a column of 10 digit integers. @@ -206,47 +212,49 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: If the column contains data that is neither a datetime-like nor numeric. """ - if isinstance(column.iloc[0], datetime.datetime): - column = self._clip_to_seconds(column.astype(np.int64)) - elif np.issubdtype(column.iloc[0], np.integer): + if pd.api.types.is_datetime64_any_dtype(column): + integers = self._clip_to_seconds(column.astype(np.int64)) + elif pd.api.types.is_integer_dtype(column): if not len(column >= 0) == len(column): raise RandomnessError( "Values in integer columns must be greater than or equal to zero." ) - column = self._spread(column) - elif np.issubdtype(column.iloc[0], np.floating): - column = self._shift(column) + integers = self._spread(column.astype(int)) + elif pd.api.types.is_float_dtype(column): + integers = self._shift(column.astype(float)) else: raise RandomnessError( f"Unhashable column type {type(column.iloc[0])}. " "IndexMap accepts datetime like columns and numeric columns." ) - return column + return integers @staticmethod - def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]: + def _digit(m: pd.Series[int], n: int) -> pd.Series[int]: """Returns the nth digit of each number in m.""" - return (m // (10**n)) % 10 + nth_digits: pd.Series[int] = (m // (10**n)) % 10 + return nth_digits @staticmethod - def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]: + def _clip_to_seconds(m: pd.Series[int]) -> pd.Series[int]: """Clips UTC datetime in nanoseconds to seconds.""" return m // pd.Timedelta(1, unit="s").value - def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]: + def _spread(self, m: pd.Series[int]) -> pd.Series[int]: """Spreads out integer values to give smaller values more weight.""" return (m * 111_111) % self.TEN_DIGIT_MODULUS - def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]: + def _shift(self, m: pd.Series[float]) -> pd.Series[int]: """Shifts floats so that the first 10 decimal digits are significant.""" out = m % 1 * self.TEN_DIGIT_MODULUS // 1 - if isinstance(out, pd.Series): - return out.astype("int64") - return int(out) + return out.astype("int64") - def __getitem__(self, index: pd.Index) -> np.ndarray: + def __getitem__(self, index: pd.Index[int]) -> np.ndarray[int, Any]: if self._use_crn: - return self._map.loc[index].values + if self._map is None: + raise RandomnessError("IndexMap is empty") + else: + return self._map.loc[index].to_numpy() else: return index.values diff --git a/tests/framework/randomness/test_index_map.py b/tests/framework/randomness/test_index_map.py index 8e7e9e5e9..67c1e33b4 100644 --- a/tests/framework/randomness/test_index_map.py +++ b/tests/framework/randomness/test_index_map.py @@ -61,13 +61,6 @@ def map_size_and_hashed_values(request): return len(m), m._hash(keys) -def test_digit_scalar(): - m = IndexMap() - k = 123456789 - for i in range(10): - assert m._digit(k, i) == 10 - (i + 1) - - def test_digit_series(): m = IndexMap() k = pd.Series(123456789, index=range(10000)) @@ -94,11 +87,6 @@ def test_clip_to_seconds_series(): assert m._clip_to_seconds(k).unique()[0] == stamp -def test_spread_scalar(): - m = IndexMap() - assert m._spread(1234567890) == 4072825790 - - def test_spread_series(): m = IndexMap() s = pd.Series(1234567890, index=range(10000)) @@ -106,11 +94,6 @@ def test_spread_series(): assert m._spread(s).unique()[0] == 4072825790 -def test_shift_scalar(): - m = IndexMap() - assert m._shift(1.1234567890) == 1234567890 - - def test_shift_series(): m = IndexMap() s = pd.Series(1.1234567890, index=range(10000)) From 8d4fd7c9635e6d0d90123b7ef6772ca138e6c0ca Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 19 Sep 2024 15:32:11 -0700 Subject: [PATCH 2/3] review updates --- CHANGELOG.rst | 6 ++++- .../framework/randomness/index_map.py | 23 +++++++++---------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d21afe1fb..3a2362430 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,11 @@ -**3.0.5 - 09/20/24** +**3.0.6 - 09/20/24** - Fix mypy errors: vivarium/framework/randomness/index_map.py +**3.0.5 - 09/17/24** + + - Pin Sphinx below 8.0 + **3.0.4 - 09/12/24** - Introduce static type checking with mypy diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index 267717016..dcf9dcafd 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -12,11 +12,12 @@ from __future__ import annotations -import datetime -from typing import Any, Optional, Union +from datetime import datetime +from typing import Any import numpy as np import pandas as pd +from pandas.api import types as pdtypes from vivarium.framework.randomness.exceptions import RandomnessError @@ -27,10 +28,10 @@ class IndexMap: SIM_INDEX_COLUMN = "simulant_index" TEN_DIGIT_MODULUS = 10_000_000_000 - def __init__(self, key_columns: Optional[list[str]] = None, size: int = 1_000_000): + def __init__(self, key_columns: list[str] | None = None, size: int = 1_000_000): self._use_crn = bool(key_columns) - self._key_columns = [] if key_columns is None else key_columns - self._map: Optional[pd.Series[int]] = None + self._key_columns = key_columns if key_columns else [] + self._map: pd.Series[int] | None = None """The mapping between the key columns and the randomness index.""" self._size = size @@ -153,9 +154,7 @@ def _resolve_collisions( salt += 1 return current_mapping - def _hash( - self, keys: pd.Index[Any], salt: Union[int, pd.Timestamp] = 0 - ) -> pd.Series[int]: + def _hash(self, keys: pd.Index[Any], salt: int | pd.Timestamp = 0) -> pd.Series[int]: """Hashes the index into an integer index in the range [0, self.stride] Parameters @@ -192,7 +191,7 @@ def _hash( return new_map % len(self) def _convert_to_ten_digit_int( - self, column: pd.Series[Union[datetime.datetime, int, float]] + self, column: pd.Series[datetime | int | float] ) -> pd.Series[int]: """Converts a column of datetimes, integers, or floats into a column of 10 digit integers. @@ -212,15 +211,15 @@ def _convert_to_ten_digit_int( If the column contains data that is neither a datetime-like nor numeric. """ - if pd.api.types.is_datetime64_any_dtype(column): + if pdtypes.is_datetime64_any_dtype(column): integers = self._clip_to_seconds(column.astype(np.int64)) - elif pd.api.types.is_integer_dtype(column): + elif pdtypes.is_integer_dtype(column): if not len(column >= 0) == len(column): raise RandomnessError( "Values in integer columns must be greater than or equal to zero." ) integers = self._spread(column.astype(int)) - elif pd.api.types.is_float_dtype(column): + elif pdtypes.is_float_dtype(column): integers = self._shift(column.astype(float)) else: raise RandomnessError( From 4e5f7c42630d9ad0d3cdbeeb625fdd4610153985 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Fri, 20 Sep 2024 08:32:04 -0700 Subject: [PATCH 3/3] change imported alias --- src/vivarium/framework/randomness/index_map.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index dcf9dcafd..9c1ee6f99 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -17,7 +17,7 @@ import numpy as np import pandas as pd -from pandas.api import types as pdtypes +import pandas.api.types as pdt from vivarium.framework.randomness.exceptions import RandomnessError @@ -211,15 +211,15 @@ def _convert_to_ten_digit_int( If the column contains data that is neither a datetime-like nor numeric. """ - if pdtypes.is_datetime64_any_dtype(column): + if pdt.is_datetime64_any_dtype(column): integers = self._clip_to_seconds(column.astype(np.int64)) - elif pdtypes.is_integer_dtype(column): + elif pdt.is_integer_dtype(column): if not len(column >= 0) == len(column): raise RandomnessError( "Values in integer columns must be greater than or equal to zero." ) integers = self._spread(column.astype(int)) - elif pdtypes.is_float_dtype(column): + elif pdt.is_float_dtype(column): integers = self._shift(column.astype(float)) else: raise RandomnessError(