Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy fix index_map.py #482

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.0.6 - 09/20/24**

- Fix mypy errors: vivarium/framework/randomness/index_map.py

**3.0.5 - 09/17/24**

- Pin Sphinx below 8.0
Expand Down
2 changes: 2 additions & 0 deletions docs/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ py:class pandas.core.frame.DataFrame
py:class pandas.core.series.Series
py:class pandas.core.generic.PandasObject
py:class pandas.core.groupby.generic.DataFrameGroupBy
py:class pd.Series
py:class pd.Index

#scipy
py:class scipy.stats._distn_infrastructure.rv_continuous
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ exclude = [
'src/vivarium/framework/plugins.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/randomness/index_map.py',
'src/vivarium/framework/randomness/manager.py',
'src/vivarium/framework/randomness/stream.py',
'src/vivarium/framework/resource.py',
Expand Down
79 changes: 43 additions & 36 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
===================
Randomness IndexMap
Expand All @@ -11,11 +10,14 @@

"""

import datetime
from typing import List, Tuple, Union
from __future__ import annotations

from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
from pandas.api import types as pdtypes

rmudambi marked this conversation as resolved.
Show resolved Hide resolved
from vivarium.framework.randomness.exceptions import RandomnessError

Expand All @@ -26,10 +28,11 @@ class IndexMap:
SIM_INDEX_COLUMN = "simulant_index"
TEN_DIGIT_MODULUS = 10_000_000_000

def __init__(self, key_columns: List[str] = None, size: int = 1_000_000):
def __init__(self, key_columns: list[str] | None = None, size: int = 1_000_000):
self._use_crn = bool(key_columns)
self._key_columns = key_columns
self._map = None
self._key_columns = key_columns if key_columns else []
self._map: pd.Series[int] | None = None
"""The mapping between the key columns and the randomness index."""
self._size = size

def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
Expand Down Expand Up @@ -62,7 +65,7 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN)
self._map = final_mapping

def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]:
def _parse_new_keys(self, new_keys: pd.DataFrame) -> tuple[pd.Index[Any], pd.Index[Any]]:
"""Parses raw new keys into the mapping index.
rmudambi marked this conversation as resolved.
Show resolved Hide resolved

This returns a tuple of the new and final mapping indices. Both are pandas
Expand All @@ -88,12 +91,12 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul
if self._map is None:
final_mapping_index = new_mapping_index
else:
final_mapping_index = self._map.index.append(new_mapping_index)
final_mapping_index = self._map.index.append(new_mapping_index) # type: ignore [no-untyped-call]
return new_mapping_index, final_mapping_index
rmudambi marked this conversation as resolved.
Show resolved Hide resolved

def _build_final_mapping(
self, new_mapping_index: pd.Index, clock_time: pd.Timestamp
) -> pd.Series:
self, new_mapping_index: pd.Index[Any], clock_time: pd.Timestamp
) -> pd.Series[int]:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.

Expand Down Expand Up @@ -123,9 +126,9 @@ def _build_final_mapping(

def _resolve_collisions(
self,
new_key_index: pd.MultiIndex,
current_mapping: pd.Series,
) -> pd.Series:
new_key_index: pd.Index[Any],
current_mapping: pd.Series[int],
) -> pd.Series[int]:
"""Resolves collisions in the new mapping by perturbing the hash.

Parameters
Expand All @@ -151,15 +154,15 @@ def _resolve_collisions(
salt += 1
return current_mapping

def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
def _hash(self, keys: pd.Index[Any], salt: int | pd.Timestamp = 0) -> pd.Series[int]:
"""Hashes the index into an integer index in the range [0, self.stride]

Parameters
----------
keys
The new index to hash.
salt
An integer used to perturb the hash in a deterministic way. Useful
Value used to perturb the hash in a deterministic way. Useful
in dealing with collisions.

Returns
Expand All @@ -170,9 +173,9 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
"""
key_frame = keys.to_frame()
new_map = pd.Series(0, index=keys)
salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys))
salt_series = self._convert_to_ten_digit_int(pd.Series(salt, index=keys))

for i, column_name in enumerate(key_frame.columns):
for _i, column_name in enumerate(key_frame.columns):
column = self._convert_to_ten_digit_int(key_frame[column_name])

primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27]
Expand All @@ -183,11 +186,13 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
# our map size the amount of additional periodicity this
# introduces is pretty trivial.
out *= np.power(p, self._digit(column, idx))
new_map += out + salt
new_map += out + salt_series

return new_map % len(self)

def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
def _convert_to_ten_digit_int(
self, column: pd.Series[datetime | int | float]
) -> pd.Series[int]:
"""Converts a column of datetimes, integers, or floats into a column
of 10 digit integers.

Expand All @@ -206,47 +211,49 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
If the column contains data that is neither a datetime-like nor
numeric.
"""
if isinstance(column.iloc[0], datetime.datetime):
column = self._clip_to_seconds(column.astype(np.int64))
elif np.issubdtype(column.iloc[0], np.integer):
if pdtypes.is_datetime64_any_dtype(column):
integers = self._clip_to_seconds(column.astype(np.int64))
elif pdtypes.is_integer_dtype(column):
if not len(column >= 0) == len(column):
raise RandomnessError(
"Values in integer columns must be greater than or equal to zero."
)
column = self._spread(column)
elif np.issubdtype(column.iloc[0], np.floating):
column = self._shift(column)
integers = self._spread(column.astype(int))
elif pdtypes.is_float_dtype(column):
integers = self._shift(column.astype(float))
else:
raise RandomnessError(
f"Unhashable column type {type(column.iloc[0])}. "
"IndexMap accepts datetime like columns and numeric columns."
)
return column
return integers

@staticmethod
def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]:
def _digit(m: pd.Series[int], n: int) -> pd.Series[int]:
"""Returns the nth digit of each number in m."""
return (m // (10**n)) % 10
nth_digits: pd.Series[int] = (m // (10**n)) % 10
return nth_digits

@staticmethod
def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _clip_to_seconds(m: pd.Series[int]) -> pd.Series[int]:
"""Clips UTC datetime in nanoseconds to seconds."""
return m // pd.Timedelta(1, unit="s").value

def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _spread(self, m: pd.Series[int]) -> pd.Series[int]:
"""Spreads out integer values to give smaller values more weight."""
return (m * 111_111) % self.TEN_DIGIT_MODULUS

def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
def _shift(self, m: pd.Series[float]) -> pd.Series[int]:
"""Shifts floats so that the first 10 decimal digits are significant."""
out = m % 1 * self.TEN_DIGIT_MODULUS // 1
if isinstance(out, pd.Series):
return out.astype("int64")
return int(out)
return out.astype("int64")

def __getitem__(self, index: pd.Index) -> np.ndarray:
def __getitem__(self, index: pd.Index[int]) -> np.ndarray[int, Any]:
if self._use_crn:
return self._map.loc[index].values
if self._map is None:
raise RandomnessError("IndexMap is empty")
else:
return self._map.loc[index].to_numpy()
else:
return index.values

Expand Down
17 changes: 0 additions & 17 deletions tests/framework/randomness/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ def map_size_and_hashed_values(request):
return len(m), m._hash(keys)


def test_digit_scalar():
rmudambi marked this conversation as resolved.
Show resolved Hide resolved
m = IndexMap()
k = 123456789
for i in range(10):
assert m._digit(k, i) == 10 - (i + 1)


def test_digit_series():
m = IndexMap()
k = pd.Series(123456789, index=range(10000))
Expand All @@ -94,23 +87,13 @@ def test_clip_to_seconds_series():
assert m._clip_to_seconds(k).unique()[0] == stamp


def test_spread_scalar():
m = IndexMap()
assert m._spread(1234567890) == 4072825790


def test_spread_series():
m = IndexMap()
s = pd.Series(1234567890, index=range(10000))
assert len(m._spread(s).unique()) == 1
assert m._spread(s).unique()[0] == 4072825790


def test_shift_scalar():
m = IndexMap()
assert m._shift(1.1234567890) == 1234567890


def test_shift_series():
m = IndexMap()
s = pd.Series(1.1234567890, index=range(10000))
Expand Down
Loading