Skip to content

Commit

Permalink
mypy fix index_map.py (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Sep 20, 2024
1 parent 82f2868 commit 25cd815
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 54 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.0.6 - 09/20/24**

- Fix mypy errors: vivarium/framework/randomness/index_map.py

**3.0.5 - 09/17/24**

- Pin Sphinx below 8.0
Expand Down
2 changes: 2 additions & 0 deletions docs/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ py:class pandas.core.frame.DataFrame
py:class pandas.core.series.Series
py:class pandas.core.generic.PandasObject
py:class pandas.core.groupby.generic.DataFrameGroupBy
py:class pd.Series
py:class pd.Index

#scipy
py:class scipy.stats._distn_infrastructure.rv_continuous
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ exclude = [
'src/vivarium/framework/plugins.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/randomness/index_map.py',
'src/vivarium/framework/randomness/manager.py',
'src/vivarium/framework/randomness/stream.py',
'src/vivarium/framework/resource.py',
Expand Down
79 changes: 43 additions & 36 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
===================
Randomness IndexMap
Expand All @@ -11,11 +10,14 @@
"""

import datetime
from typing import List, Tuple, Union
from __future__ import annotations

from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
import pandas.api.types as pdt

from vivarium.framework.randomness.exceptions import RandomnessError

Expand All @@ -26,10 +28,11 @@ class IndexMap:
SIM_INDEX_COLUMN = "simulant_index"
TEN_DIGIT_MODULUS = 10_000_000_000

def __init__(self, key_columns: List[str] = None, size: int = 1_000_000):
def __init__(self, key_columns: list[str] | None = None, size: int = 1_000_000):
self._use_crn = bool(key_columns)
self._key_columns = key_columns
self._map = None
self._key_columns = key_columns if key_columns else []
self._map: pd.Series[int] | None = None
"""The mapping between the key columns and the randomness index."""
self._size = size

def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
Expand Down Expand Up @@ -62,7 +65,7 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN)
self._map = final_mapping

def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]:
def _parse_new_keys(self, new_keys: pd.DataFrame) -> tuple[pd.Index[Any], pd.Index[Any]]:
"""Parses raw new keys into the mapping index.
This returns a tuple of the new and final mapping indices. Both are pandas
Expand All @@ -88,12 +91,12 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul
if self._map is None:
final_mapping_index = new_mapping_index
else:
final_mapping_index = self._map.index.append(new_mapping_index)
final_mapping_index = self._map.index.append(new_mapping_index) # type: ignore [no-untyped-call]
return new_mapping_index, final_mapping_index

def _build_final_mapping(
self, new_mapping_index: pd.Index, clock_time: pd.Timestamp
) -> pd.Series:
self, new_mapping_index: pd.Index[Any], clock_time: pd.Timestamp
) -> pd.Series[int]:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.
Expand Down Expand Up @@ -123,9 +126,9 @@ def _build_final_mapping(

def _resolve_collisions(
self,
new_key_index: pd.MultiIndex,
current_mapping: pd.Series,
) -> pd.Series:
new_key_index: pd.Index[Any],
current_mapping: pd.Series[int],
) -> pd.Series[int]:
"""Resolves collisions in the new mapping by perturbing the hash.
Parameters
Expand All @@ -151,15 +154,15 @@ def _resolve_collisions(
salt += 1
return current_mapping

def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
def _hash(self, keys: pd.Index[Any], salt: int | pd.Timestamp = 0) -> pd.Series[int]:
"""Hashes the index into an integer index in the range [0, self.stride]
Parameters
----------
keys
The new index to hash.
salt
An integer used to perturb the hash in a deterministic way. Useful
Value used to perturb the hash in a deterministic way. Useful
in dealing with collisions.
Returns
Expand All @@ -170,9 +173,9 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
"""
key_frame = keys.to_frame()
new_map = pd.Series(0, index=keys)
salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys))
salt_series = self._convert_to_ten_digit_int(pd.Series(salt, index=keys))

for i, column_name in enumerate(key_frame.columns):
for _i, column_name in enumerate(key_frame.columns):
column = self._convert_to_ten_digit_int(key_frame[column_name])

primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27]
Expand All @@ -183,11 +186,13 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
# our map size the amount of additional periodicity this
# introduces is pretty trivial.
out *= np.power(p, self._digit(column, idx))
new_map += out + salt
new_map += out + salt_series

return new_map % len(self)

def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
def _convert_to_ten_digit_int(
self, column: pd.Series[datetime | int | float]
) -> pd.Series[int]:
"""Converts a column of datetimes, integers, or floats into a column
of 10 digit integers.
Expand All @@ -206,47 +211,49 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
If the column contains data that is neither a datetime-like nor
numeric.
"""
if isinstance(column.iloc[0], datetime.datetime):
column = self._clip_to_seconds(column.astype(np.int64))
elif np.issubdtype(column.iloc[0], np.integer):
if pdt.is_datetime64_any_dtype(column):
integers = self._clip_to_seconds(column.astype(np.int64))
elif pdt.is_integer_dtype(column):
if not len(column >= 0) == len(column):
raise RandomnessError(
"Values in integer columns must be greater than or equal to zero."
)
column = self._spread(column)
elif np.issubdtype(column.iloc[0], np.floating):
column = self._shift(column)
integers = self._spread(column.astype(int))
elif pdt.is_float_dtype(column):
integers = self._shift(column.astype(float))
else:
raise RandomnessError(
f"Unhashable column type {type(column.iloc[0])}. "
"IndexMap accepts datetime like columns and numeric columns."
)
return column
return integers

@staticmethod
def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]:
def _digit(m: pd.Series[int], n: int) -> pd.Series[int]:
"""Returns the nth digit of each number in m."""
return (m // (10**n)) % 10
nth_digits: pd.Series[int] = (m // (10**n)) % 10
return nth_digits

@staticmethod
def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _clip_to_seconds(m: pd.Series[int]) -> pd.Series[int]:
"""Clips UTC datetime in nanoseconds to seconds."""
return m // pd.Timedelta(1, unit="s").value

def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _spread(self, m: pd.Series[int]) -> pd.Series[int]:
"""Spreads out integer values to give smaller values more weight."""
return (m * 111_111) % self.TEN_DIGIT_MODULUS

def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
def _shift(self, m: pd.Series[float]) -> pd.Series[int]:
"""Shifts floats so that the first 10 decimal digits are significant."""
out = m % 1 * self.TEN_DIGIT_MODULUS // 1
if isinstance(out, pd.Series):
return out.astype("int64")
return int(out)
return out.astype("int64")

def __getitem__(self, index: pd.Index) -> np.ndarray:
def __getitem__(self, index: pd.Index[int]) -> np.ndarray[int, Any]:
if self._use_crn:
return self._map.loc[index].values
if self._map is None:
raise RandomnessError("IndexMap is empty")
else:
return self._map.loc[index].to_numpy()
else:
return index.values

Expand Down
17 changes: 0 additions & 17 deletions tests/framework/randomness/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ def map_size_and_hashed_values(request):
return len(m), m._hash(keys)


def test_digit_scalar():
m = IndexMap()
k = 123456789
for i in range(10):
assert m._digit(k, i) == 10 - (i + 1)


def test_digit_series():
m = IndexMap()
k = pd.Series(123456789, index=range(10000))
Expand All @@ -94,23 +87,13 @@ def test_clip_to_seconds_series():
assert m._clip_to_seconds(k).unique()[0] == stamp


def test_spread_scalar():
m = IndexMap()
assert m._spread(1234567890) == 4072825790


def test_spread_series():
m = IndexMap()
s = pd.Series(1234567890, index=range(10000))
assert len(m._spread(s).unique()) == 1
assert m._spread(s).unique()[0] == 4072825790


def test_shift_scalar():
m = IndexMap()
assert m._shift(1.1234567890) == 1234567890


def test_shift_series():
m = IndexMap()
s = pd.Series(1.1234567890, index=range(10000))
Expand Down

0 comments on commit 25cd815

Please sign in to comment.