Skip to content

Commit

Permalink
Sbachmei/mic 5395/mypy event (#516)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Oct 24, 2024
1 parent 69cf954 commit a42e397
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 35 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.0.15 - TBD/TBD/TBD**

- Fix mypy errors in vivarium/framework/event.py

**3.0.14 - 10/18/24**

- Fix mypy errors in vivarium/framework/artifact/artifact.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ exclude = [
'src/vivarium/framework/components/parser.py',
'src/vivarium/framework/configuration.py',
'src/vivarium/framework/engine.py',
'src/vivarium/framework/event.py',
'src/vivarium/framework/logging/manager.py',
'src/vivarium/framework/lookup/manager.py',
'src/vivarium/framework/lookup/table.py',
Expand Down
79 changes: 45 additions & 34 deletions src/vivarium/framework/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
============================
The Vivarium Event Framework
Expand Down Expand Up @@ -28,36 +27,42 @@
"""

from typing import Any, Callable, Dict, List, NamedTuple, Optional
from __future__ import annotations

from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, Any, NamedTuple

import pandas as pd

from vivarium.framework.lifecycle import ConstraintError
from vivarium.manager import Interface, Manager
from vivarium.types import ClockStepSize, ClockTime

if TYPE_CHECKING:
from vivarium.framework.engine import Builder


class Event(NamedTuple):
"""An Event object represents the context of an event.
Events themselves are just a bundle of data. They must be emitted
along an :class:`EventChannel` in order for other simulation components
to respond to them.
"""

#: An index into the population table containing all simulants affected
#: by this event.
index: pd.Index
#: Any additional data provided by the user about the event.
user_data: Dict[str, Any]
#: The simulation time at which this event will resolve. The current
#: simulation size plus the current time step size.
# FIXME [MIC-5468]: fix index type hint for mypy
index: pd.Index[int] # type: ignore[assignment]
"""An index into the population table containing all simulants affected by this event."""
user_data: dict[str, Any]
"""Any additional data provided by the user about the event."""
time: ClockTime
#: The current step size at the time of the event.
"""The simulation time at which this event will resolve. The current simulation
size plus the current time step size."""
step_size: ClockStepSize
"""The current step size at the time of the event."""

def split(self, new_index: pd.Index) -> "Event":
def split(self, new_index: pd.Index[int]) -> "Event":
"""Create a copy of this event with a new index.
This function should be used to emit an event in a new
Expand All @@ -76,24 +81,24 @@ def split(self, new_index: pd.Index) -> "Event":
"""
return Event(new_index, self.user_data, self.time, self.step_size)

def __repr__(self):
def __repr__(self) -> str:
return (
f"Event(user_data={self.user_data}, time={self.time}, step_size={self.step_size})"
)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


class EventChannel:
"""A named subscription channel that passes events to event listeners."""

def __init__(self, manager, name):
def __init__(self, manager: EventManager, name: str) -> None:
self.name = f"event_channel_{name}"
self.manager = manager
self.listeners = [[] for _ in range(10)]
self.listeners: list[list[Callable[[Event], None]]] = [[] for _ in range(10)]

def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
def emit(self, index: pd.Index[int], user_data: dict[str, Any] | None = None) -> Event:
"""Notifies all listeners to this channel that an event has occurred.
Events are emitted to listeners in order of priority (with order 0 being
Expand All @@ -113,7 +118,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
e = Event(
index,
user_data,
self.manager.clock() + self.manager.step_size(),
self.manager.clock() + self.manager.step_size(), # type: ignore[operator, arg-type]
self.manager.step_size(),
)

Expand All @@ -122,7 +127,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
listener(e)
return e

def __repr__(self):
def __repr__(self) -> str:
return f"EventChannel(listeners: {[listener for bucket in self.listeners for listener in bucket]})"


Expand All @@ -137,20 +142,20 @@ class EventManager(Manager):
"""

def __init__(self):
self._event_types = {}
def __init__(self) -> None:
self._event_types: dict[str, EventChannel] = {}

@property
def name(self):
def name(self) -> str:
"""The name of this component."""
return "event_manager"

def get_channel(self, name):
def get_channel(self, name: str) -> EventChannel:
if name not in self._event_types:
self._event_types[name] = EventChannel(self, name)
return self._event_types[name]

def setup(self, builder):
def setup(self, builder: Builder) -> None:
"""Performs this component's simulation setup.
Parameters
Expand All @@ -170,11 +175,13 @@ def setup(self, builder):
)
builder.lifecycle.add_constraint(self.register_listener, allow_during=["setup"])

def on_post_setup(self, event):
def on_post_setup(self, event: Event) -> None:
for name, channel in self._event_types.items():
self.add_handlers(name, [h for level in channel.listeners for h in level])

def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
def get_emitter(
self, name: str
) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]:
"""Get an emitter function for the named event.
Parameters
Expand All @@ -197,7 +204,9 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
pass
return channel.emit

def register_listener(self, name: str, listener: Callable, priority: int = 5) -> None:
def register_listener(
self, name: str, listener: Callable[[Event], None], priority: int = 5
) -> None:
"""Registers a new listener to the named event.
Parameters
Expand All @@ -212,7 +221,7 @@ def register_listener(self, name: str, listener: Callable, priority: int = 5) ->
"""
self.get_channel(name).listeners[priority].append(listener)

def get_listeners(self, name: str) -> Dict[int, List[Callable]]:
def get_listeners(self, name: str) -> dict[int, list[Callable[[Event], None]]]:
"""Get all listeners registered for the named event.
Parameters
Expand All @@ -232,24 +241,24 @@ def get_listeners(self, name: str) -> Dict[int, List[Callable]]:
if listeners
}

def list_events(self) -> List[Event]:
def list_events(self) -> list[str]:
"""List all event names known to the event system.
Returns
-------
A list of all known events.
A list of all known events names.
Notes
-----
This value can change after setup if components dynamically create
new event labels.
"""
return list(self._event_types.keys())
return list(self._event_types)

def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self._event_types

def __repr__(self):
def __repr__(self) -> str:
return "EventManager()"


Expand All @@ -259,7 +268,9 @@ class EventInterface(Interface):
def __init__(self, manager: EventManager):
self._manager = manager

def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
def get_emitter(
self, name: str
) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]:
"""Gets an emitter for a named event.
Parameters
Expand Down
10 changes: 10 additions & 0 deletions tests/framework/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,13 @@ def test_contains():
assert event not in manager
manager.get_emitter(event)
assert event in manager


def test_list_events():
manager = EventManager()
manager.add_constraint = lambda f, **kwargs: f
_ = manager.get_channel("event1")
_ = manager.get_emitter("event2")
_ = manager.register_listener("event3", lambda: None)
_ = manager.get_listeners("event4")
assert manager.list_events() == ["event1", "event2", "event3", "event4"]

0 comments on commit a42e397

Please # to comment.