Skip to content

Commit

Permalink
mypy fixes for boids example (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Dec 27, 2024
1 parent 6991a03 commit 8e6e3f9
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.13 - 12/27/24**

- Type-hinting: Fix mypy errors in vivarium/examples/boids/

**3.2.12 - 12/26/24**

- Type-hinting: Fix mypy errors in vivarium/framework/engine.py
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ exclude = [
# You will need to remove the mypy: ignore-errors comment from the file heading as well
'docs/source/conf.py',
'setup.py',
'src/vivarium/examples/boids/forces.py',
'src/vivarium/examples/boids/movement.py',
'src/vivarium/examples/boids/neighbors.py',
'src/vivarium/examples/boids/population.py',
'src/vivarium/examples/boids/visualization.py',
'src/vivarium/examples/disease_model/__init__.py',
'src/vivarium/examples/disease_model/disease.py',
'src/vivarium/examples/disease_model/intervention.py',
Expand Down
27 changes: 15 additions & 12 deletions src/vivarium/examples/boids/forces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: ignore-errors

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -43,7 +44,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFrame:
def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.DataFrame:
neighbors = self.neighbors(index)
pop = self.population_view.get(index)
pairs = self._get_pairs(neighbors, pop)
Expand All @@ -56,18 +57,18 @@ def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFra
max_speed=self.max_speed,
)

acceleration.loc[force.index] += force[["x", "y"]]
acceleration.loc[force.index, ["x", "y"]] += force[["x", "y"]]
return acceleration

##################
# Helper methods #
##################

@abstractmethod
def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
pass

def _get_pairs(self, neighbors: pd.Series, pop: pd.DataFrame):
def _get_pairs(self, neighbors: pd.Series[int], pop: pd.DataFrame) -> pd.DataFrame:
pairs = (
pop.join(neighbors.rename("neighbors"))
.reset_index()
Expand All @@ -91,7 +92,7 @@ def _normalize_and_limit_force(
velocity: pd.DataFrame,
max_force: float,
max_speed: float,
):
) -> pd.DataFrame:
normalization_factor = np.where(
(force.x != 0) | (force.y != 0),
max_speed / self._magnitude(force),
Expand All @@ -111,8 +112,8 @@ def _normalize_and_limit_force(
force["y"] *= limit_scaling_factor
return force[["x", "y"]]

def _magnitude(self, df: pd.DataFrame):
return np.sqrt(np.square(df.x) + np.square(df.y))
def _magnitude(self, df: pd.DataFrame) -> pd.Series[float]:
return pd.Series(np.sqrt(np.square(df.x) + np.square(df.y)), dtype=float)


class Separation(Force):
Expand All @@ -125,7 +126,7 @@ class Separation(Force):
},
}

def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
# Push boids apart when they get too close
separation_neighbors = neighbors[neighbors.distance < self.config.distance].copy()
force_scaling_factor = np.where(
Expand All @@ -140,17 +141,19 @@ def calculate_force(self, neighbors: pd.DataFrame):
separation_neighbors["distance_y"] * force_scaling_factor
)

return (
force: pd.DataFrame = (
separation_neighbors.groupby("index_self")[["force_x", "force_y"]]
.sum()
.rename(columns=lambda c: c.replace("force_", ""))
)

return force


class Cohesion(Force):
"""Push boids together."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["distance_x", "distance_y"]]
.sum()
Expand All @@ -161,7 +164,7 @@ def calculate_force(self, pairs: pd.DataFrame):
class Alignment(Force):
"""Push boids toward where others are going."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["vx_other", "vy_other"]]
.sum()
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/examples/boids/movement.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# mypy: ignore-errors
from __future__ import annotations
import numpy as np
import pandas as pd

from vivarium.framework.event import Event
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
Expand Down Expand Up @@ -38,7 +39,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def base_acceleration(self, index: pd.Index) -> pd.DataFrame:
def base_acceleration(self, index: pd.Index[int]) -> pd.DataFrame:
return pd.DataFrame(0.0, columns=["x", "y"], index=index)

########################
Expand All @@ -59,7 +60,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
)
self.population_view.update(new_population)

def on_time_step(self, event):
def on_time_step(self, event: Event) -> None:
pop = self.population_view.get(event.index)

acceleration = self.acceleration(event.index)
Expand Down
4 changes: 2 additions & 2 deletions src/vivarium/examples/boids/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# mypy: ignore-errors
from __future__ import annotations
import pandas as pd
from scipy import spatial

Expand Down Expand Up @@ -43,7 +43,7 @@ def on_time_step(self, event: Event) -> None:
# Pipeline sources and modifiers #
##################################

def get_neighbors(self, index: pd.Index) -> pd.Series:
def get_neighbors(self, index: pd.Index[int]) -> pd.Series[list[int]]: # type: ignore[type-var]
if not self.neighbors_calculated:
self._calculate_neighbors()
return self._neighbors[index]
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/population.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
import numpy as np
import pandas as pd

Expand Down
19 changes: 10 additions & 9 deletions src/vivarium/examples/boids/visualization.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,43 @@
# mypy: ignore-errors
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from vivarium import InteractiveContext

def plot_boids(simulation, plot_velocity=False):

def plot_boids(simulation: InteractiveContext, plot_velocity: bool=False) -> None:
width = simulation.configuration.field.width
height = simulation.configuration.field.height
pop = simulation.get_population()

plt.figure(figsize=[12, 12])
plt.figure(figsize=(12, 12))
plt.scatter(pop.x, pop.y, color=pop.color)
if plot_velocity:
plt.quiver(pop.x, pop.y, pop.vx, pop.vy, color=pop.color, width=0.002)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))
plt.show()


def plot_boids_animated(simulation):
def plot_boids_animated(simulation: InteractiveContext) -> FuncAnimation:
width = simulation.configuration.field.width
height = simulation.configuration.field.height
pop = simulation.get_population()

fig = plt.figure(figsize=[12, 12])
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)
s = ax.scatter(pop.x, pop.y, color=pop.color)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))

frames = range(2_000)
frame_pops = []
for _ in frames:
simulation.step()
frame_pops.append(simulation.get_population()[["x", "y"]])

def animate(i):
def animate(i: int) -> None:
s.set_offsets(frame_pops[i])

return FuncAnimation(fig, animate, frames=frames, interval=10)
return FuncAnimation(fig, animate, frames=frames, interval=10) # type: ignore[arg-type]

0 comments on commit 8e6e3f9

Please # to comment.