Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

mypy fixes for boids example #563

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.12 - 12/27/24**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to merge from main again? I already released 3.2.12


- Type-hinting: Fix mypy errors in vivarium/examples/boids/

**3.2.11 - 12/23/24**

- Type-hinting: Fix mypy errors in vivarium/framework/components/parser.py
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ exclude = [
# You will need to remove the mypy: ignore-errors comment from the file heading as well
'docs/source/conf.py',
'setup.py',
'src/vivarium/examples/boids/forces.py',
'src/vivarium/examples/boids/movement.py',
'src/vivarium/examples/boids/neighbors.py',
'src/vivarium/examples/boids/population.py',
'src/vivarium/examples/boids/visualization.py',
'src/vivarium/examples/disease_model/__init__.py',
'src/vivarium/examples/disease_model/disease.py',
'src/vivarium/examples/disease_model/intervention.py',
Expand Down
27 changes: 15 additions & 12 deletions src/vivarium/examples/boids/forces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: ignore-errors

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -43,7 +44,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFrame:
def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.DataFrame:
neighbors = self.neighbors(index)
pop = self.population_view.get(index)
pairs = self._get_pairs(neighbors, pop)
Expand All @@ -56,18 +57,18 @@ def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFra
max_speed=self.max_speed,
)

acceleration.loc[force.index] += force[["x", "y"]]
acceleration.loc[force.index, ["x", "y"]] += force[["x", "y"]]
return acceleration

##################
# Helper methods #
##################

@abstractmethod
def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
pass

def _get_pairs(self, neighbors: pd.Series, pop: pd.DataFrame):
def _get_pairs(self, neighbors: pd.Series[int], pop: pd.DataFrame) -> pd.DataFrame:
pairs = (
pop.join(neighbors.rename("neighbors"))
.reset_index()
Expand All @@ -91,7 +92,7 @@ def _normalize_and_limit_force(
velocity: pd.DataFrame,
max_force: float,
max_speed: float,
):
) -> pd.DataFrame:
normalization_factor = np.where(
(force.x != 0) | (force.y != 0),
max_speed / self._magnitude(force),
Expand All @@ -111,8 +112,8 @@ def _normalize_and_limit_force(
force["y"] *= limit_scaling_factor
return force[["x", "y"]]

def _magnitude(self, df: pd.DataFrame):
return np.sqrt(np.square(df.x) + np.square(df.y))
def _magnitude(self, df: pd.DataFrame) -> pd.Series[float]:
return pd.Series(np.sqrt(np.square(df.x) + np.square(df.y)), dtype=float)


class Separation(Force):
Expand All @@ -125,7 +126,7 @@ class Separation(Force):
},
}

def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
# Push boids apart when they get too close
separation_neighbors = neighbors[neighbors.distance < self.config.distance].copy()
force_scaling_factor = np.where(
Expand All @@ -140,17 +141,19 @@ def calculate_force(self, neighbors: pd.DataFrame):
separation_neighbors["distance_y"] * force_scaling_factor
)

return (
force: pd.DataFrame = (
separation_neighbors.groupby("index_self")[["force_x", "force_y"]]
.sum()
.rename(columns=lambda c: c.replace("force_", ""))
)

return force


class Cohesion(Force):
"""Push boids together."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["distance_x", "distance_y"]]
.sum()
Expand All @@ -161,7 +164,7 @@ def calculate_force(self, pairs: pd.DataFrame):
class Alignment(Force):
"""Push boids toward where others are going."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["vx_other", "vy_other"]]
.sum()
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/examples/boids/movement.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# mypy: ignore-errors
from __future__ import annotations
import numpy as np
import pandas as pd

from vivarium.framework.event import Event
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
Expand Down Expand Up @@ -38,7 +39,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def base_acceleration(self, index: pd.Index) -> pd.DataFrame:
def base_acceleration(self, index: pd.Index[int]) -> pd.DataFrame:
return pd.DataFrame(0.0, columns=["x", "y"], index=index)

########################
Expand All @@ -59,7 +60,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
)
self.population_view.update(new_population)

def on_time_step(self, event):
def on_time_step(self, event: Event) -> None:
pop = self.population_view.get(event.index)

acceleration = self.acceleration(event.index)
Expand Down
4 changes: 2 additions & 2 deletions src/vivarium/examples/boids/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# mypy: ignore-errors
from __future__ import annotations
import pandas as pd
from scipy import spatial

Expand Down Expand Up @@ -43,7 +43,7 @@ def on_time_step(self, event: Event) -> None:
# Pipeline sources and modifiers #
##################################

def get_neighbors(self, index: pd.Index) -> pd.Series:
def get_neighbors(self, index: pd.Index[int]) -> pd.Series[list[int]]: # type: ignore[type-var]
if not self.neighbors_calculated:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is that someone common pandas error where it doesn't list out the entire line of expected items:

Type argument "list[int]" of "Series" must be a subtype of "str | bytes | date | time | bool | <10 more items>"  [type-var]

self._calculate_neighbors()
return self._neighbors[index]
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/population.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
import numpy as np
import pandas as pd

Expand Down
19 changes: 10 additions & 9 deletions src/vivarium/examples/boids/visualization.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,43 @@
# mypy: ignore-errors
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from vivarium import InteractiveContext

def plot_boids(simulation, plot_velocity=False):

def plot_boids(simulation: InteractiveContext, plot_velocity: bool=False) -> None:
width = simulation.configuration.field.width
height = simulation.configuration.field.height
pop = simulation.get_population()

plt.figure(figsize=[12, 12])
plt.figure(figsize=(12, 12))
plt.scatter(pop.x, pop.y, color=pop.color)
if plot_velocity:
plt.quiver(pop.x, pop.y, pop.vx, pop.vy, color=pop.color, width=0.002)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))
plt.show()


def plot_boids_animated(simulation):
def plot_boids_animated(simulation: InteractiveContext) -> FuncAnimation:
width = simulation.configuration.field.width
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an InteractiveContext or a SimulationContext?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really sure tbh. The example docs is all about interactive, but maybe it could be used in a simulation context?

height = simulation.configuration.field.height
pop = simulation.get_population()

fig = plt.figure(figsize=[12, 12])
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)
s = ax.scatter(pop.x, pop.y, color=pop.color)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))

frames = range(2_000)
frame_pops = []
for _ in frames:
simulation.step()
frame_pops.append(simulation.get_population()[["x", "y"]])

def animate(i):
def animate(i: int) -> None:
s.set_offsets(frame_pops[i])

return FuncAnimation(fig, animate, frames=frames, interval=10)
return FuncAnimation(fig, animate, frames=frames, interval=10) # type: ignore[arg-type]
Loading