Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Sbachmei/mic 5395/mypy event #516

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.0.15 - TBD/TBD/TBD**

- Fix mypy errors in vivarium/framework/event.py

**3.0.14 - 10/18/24**

- Fix mypy errors in vivarium/framework/artifact/artifact.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ exclude = [
'src/vivarium/framework/components/parser.py',
'src/vivarium/framework/configuration.py',
'src/vivarium/framework/engine.py',
'src/vivarium/framework/event.py',
'src/vivarium/framework/logging/manager.py',
'src/vivarium/framework/lookup/manager.py',
'src/vivarium/framework/lookup/table.py',
Expand Down
79 changes: 45 additions & 34 deletions src/vivarium/framework/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
============================
The Vivarium Event Framework
Expand Down Expand Up @@ -28,36 +27,42 @@

"""

from typing import Any, Callable, Dict, List, NamedTuple, Optional
from __future__ import annotations

from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, Any, NamedTuple

import pandas as pd

from vivarium.framework.lifecycle import ConstraintError
from vivarium.manager import Interface, Manager
from vivarium.types import ClockStepSize, ClockTime

if TYPE_CHECKING:
from vivarium.framework.engine import Builder


class Event(NamedTuple):
"""An Event object represents the context of an event.

Events themselves are just a bundle of data. They must be emitted
along an :class:`EventChannel` in order for other simulation components
to respond to them.

"""

#: An index into the population table containing all simulants affected
#: by this event.
index: pd.Index
#: Any additional data provided by the user about the event.
user_data: Dict[str, Any]
#: The simulation time at which this event will resolve. The current
#: simulation size plus the current time step size.
# FIXME [MIC-5468]: fix index type hint for mypy
index: pd.Index[int] # type: ignore[assignment]
"""An index into the population table containing all simulants affected by this event."""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignored because NamedTuple has an index method. We can either change the name of this attr or switch to a @dataclass(frozen=True)

user_data: dict[str, Any]
"""Any additional data provided by the user about the event."""
time: ClockTime
#: The current step size at the time of the event.
"""The simulation time at which this event will resolve. The current simulation
size plus the current time step size."""
step_size: ClockStepSize
"""The current step size at the time of the event."""

def split(self, new_index: pd.Index) -> "Event":
def split(self, new_index: pd.Index[int]) -> "Event":
"""Create a copy of this event with a new index.

This function should be used to emit an event in a new
Expand All @@ -76,24 +81,24 @@ def split(self, new_index: pd.Index) -> "Event":
"""
return Event(new_index, self.user_data, self.time, self.step_size)

def __repr__(self):
def __repr__(self) -> str:
return (
f"Event(user_data={self.user_data}, time={self.time}, step_size={self.step_size})"
)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


class EventChannel:
"""A named subscription channel that passes events to event listeners."""

def __init__(self, manager, name):
def __init__(self, manager: EventManager, name: str) -> None:
self.name = f"event_channel_{name}"
self.manager = manager
self.listeners = [[] for _ in range(10)]
self.listeners: list[list[Callable[[Event], None]]] = [[] for _ in range(10)]

def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
def emit(self, index: pd.Index[int], user_data: dict[str, Any] | None = None) -> Event:
"""Notifies all listeners to this channel that an event has occurred.

Events are emitted to listeners in order of priority (with order 0 being
Expand All @@ -113,7 +118,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
e = Event(
index,
user_data,
self.manager.clock() + self.manager.step_size(),
self.manager.clock() + self.manager.step_size(), # type: ignore[operator, arg-type]
self.manager.step_size(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to sort out the clock type hints here. They include Number which don't all have add implemented.

)

Expand All @@ -122,7 +127,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event:
listener(e)
return e

def __repr__(self):
def __repr__(self) -> str:
return f"EventChannel(listeners: {[listener for bucket in self.listeners for listener in bucket]})"


Expand All @@ -137,20 +142,20 @@ class EventManager(Manager):

"""

def __init__(self):
self._event_types = {}
def __init__(self) -> None:
self._event_types: dict[str, EventChannel] = {}

@property
def name(self):
def name(self) -> str:
"""The name of this component."""
return "event_manager"

def get_channel(self, name):
def get_channel(self, name: str) -> EventChannel:
if name not in self._event_types:
self._event_types[name] = EventChannel(self, name)
return self._event_types[name]

def setup(self, builder):
def setup(self, builder: Builder) -> None:
"""Performs this component's simulation setup.

Parameters
Expand All @@ -170,11 +175,13 @@ def setup(self, builder):
)
builder.lifecycle.add_constraint(self.register_listener, allow_during=["setup"])

def on_post_setup(self, event):
def on_post_setup(self, event: Event) -> None:
for name, channel in self._event_types.items():
self.add_handlers(name, [h for level in channel.listeners for h in level])

def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
def get_emitter(
self, name: str
) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]:
"""Get an emitter function for the named event.

Parameters
Expand All @@ -197,7 +204,9 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
pass
return channel.emit

def register_listener(self, name: str, listener: Callable, priority: int = 5) -> None:
def register_listener(
self, name: str, listener: Callable[[Event], None], priority: int = 5
) -> None:
"""Registers a new listener to the named event.

Parameters
Expand All @@ -212,7 +221,7 @@ def register_listener(self, name: str, listener: Callable, priority: int = 5) ->
"""
self.get_channel(name).listeners[priority].append(listener)

def get_listeners(self, name: str) -> Dict[int, List[Callable]]:
def get_listeners(self, name: str) -> dict[int, list[Callable[[Event], None]]]:
"""Get all listeners registered for the named event.

Parameters
Expand All @@ -232,24 +241,24 @@ def get_listeners(self, name: str) -> Dict[int, List[Callable]]:
if listeners
}

def list_events(self) -> List[Event]:
def list_events(self) -> list[str]:
"""List all event names known to the event system.

Returns
-------
A list of all known events.
A list of all known events names.

Notes
-----
This value can change after setup if components dynamically create
new event labels.
"""
return list(self._event_types.keys())
return list(self._event_types)

def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self._event_types

def __repr__(self):
def __repr__(self) -> str:
return "EventManager()"


Expand All @@ -259,7 +268,9 @@ class EventInterface(Interface):
def __init__(self, manager: EventManager):
self._manager = manager

def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]:
def get_emitter(
self, name: str
) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]:
"""Gets an emitter for a named event.

Parameters
Expand Down
10 changes: 10 additions & 0 deletions tests/framework/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,13 @@ def test_contains():
assert event not in manager
manager.get_emitter(event)
assert event in manager


def test_list_events():
manager = EventManager()
manager.add_constraint = lambda f, **kwargs: f
_ = manager.get_channel("event1")
_ = manager.get_emitter("event2")
_ = manager.register_listener("event3", lambda: None)
_ = manager.get_listeners("event4")
assert manager.list_events() == ["event1", "event2", "event3", "event4"]
Loading