From a42e397d10ea15b929503de8f60a5a84e6cdef9a Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:48:43 -0600 Subject: [PATCH] Sbachmei/mic 5395/mypy event (#516) --- CHANGELOG.rst | 4 ++ pyproject.toml | 1 - src/vivarium/framework/event.py | 79 +++++++++++++++++++-------------- tests/framework/test_event.py | 10 +++++ 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0a8d691fa..4b92d4514 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.0.15 - TBD/TBD/TBD** + + - Fix mypy errors in vivarium/framework/event.py + **3.0.14 - 10/18/24** - Fix mypy errors in vivarium/framework/artifact/artifact.py diff --git a/pyproject.toml b/pyproject.toml index 323013677..3609732da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ exclude = [ 'src/vivarium/framework/components/parser.py', 'src/vivarium/framework/configuration.py', 'src/vivarium/framework/engine.py', - 'src/vivarium/framework/event.py', 'src/vivarium/framework/logging/manager.py', 'src/vivarium/framework/lookup/manager.py', 'src/vivarium/framework/lookup/table.py', diff --git a/src/vivarium/framework/event.py b/src/vivarium/framework/event.py index 29ca84ec1..d5805970b 100644 --- a/src/vivarium/framework/event.py +++ b/src/vivarium/framework/event.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ============================ The Vivarium Event Framework @@ -28,7 +27,11 @@ """ -from typing import Any, Callable, Dict, List, NamedTuple, Optional +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime +from typing import TYPE_CHECKING, Any, NamedTuple import pandas as pd @@ -36,6 +39,9 @@ from vivarium.manager import Interface, Manager from vivarium.types import ClockStepSize, ClockTime +if TYPE_CHECKING: + from vivarium.framework.engine import Builder + class Event(NamedTuple): """An Event object represents the context of an event. @@ -43,21 +49,20 @@ class Event(NamedTuple): Events themselves are just a bundle of data. They must be emitted along an :class:`EventChannel` in order for other simulation components to respond to them. - """ - #: An index into the population table containing all simulants affected - #: by this event. - index: pd.Index - #: Any additional data provided by the user about the event. - user_data: Dict[str, Any] - #: The simulation time at which this event will resolve. The current - #: simulation size plus the current time step size. + # FIXME [MIC-5468]: fix index type hint for mypy + index: pd.Index[int] # type: ignore[assignment] + """An index into the population table containing all simulants affected by this event.""" + user_data: dict[str, Any] + """Any additional data provided by the user about the event.""" time: ClockTime - #: The current step size at the time of the event. + """The simulation time at which this event will resolve. The current simulation + size plus the current time step size.""" step_size: ClockStepSize + """The current step size at the time of the event.""" - def split(self, new_index: pd.Index) -> "Event": + def split(self, new_index: pd.Index[int]) -> "Event": """Create a copy of this event with a new index. This function should be used to emit an event in a new @@ -76,24 +81,24 @@ def split(self, new_index: pd.Index) -> "Event": """ return Event(new_index, self.user_data, self.time, self.step_size) - def __repr__(self): + def __repr__(self) -> str: return ( f"Event(user_data={self.user_data}, time={self.time}, step_size={self.step_size})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ class EventChannel: """A named subscription channel that passes events to event listeners.""" - def __init__(self, manager, name): + def __init__(self, manager: EventManager, name: str) -> None: self.name = f"event_channel_{name}" self.manager = manager - self.listeners = [[] for _ in range(10)] + self.listeners: list[list[Callable[[Event], None]]] = [[] for _ in range(10)] - def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event: + def emit(self, index: pd.Index[int], user_data: dict[str, Any] | None = None) -> Event: """Notifies all listeners to this channel that an event has occurred. Events are emitted to listeners in order of priority (with order 0 being @@ -113,7 +118,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event: e = Event( index, user_data, - self.manager.clock() + self.manager.step_size(), + self.manager.clock() + self.manager.step_size(), # type: ignore[operator, arg-type] self.manager.step_size(), ) @@ -122,7 +127,7 @@ def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event: listener(e) return e - def __repr__(self): + def __repr__(self) -> str: return f"EventChannel(listeners: {[listener for bucket in self.listeners for listener in bucket]})" @@ -137,20 +142,20 @@ class EventManager(Manager): """ - def __init__(self): - self._event_types = {} + def __init__(self) -> None: + self._event_types: dict[str, EventChannel] = {} @property - def name(self): + def name(self) -> str: """The name of this component.""" return "event_manager" - def get_channel(self, name): + def get_channel(self, name: str) -> EventChannel: if name not in self._event_types: self._event_types[name] = EventChannel(self, name) return self._event_types[name] - def setup(self, builder): + def setup(self, builder: Builder) -> None: """Performs this component's simulation setup. Parameters @@ -170,11 +175,13 @@ def setup(self, builder): ) builder.lifecycle.add_constraint(self.register_listener, allow_during=["setup"]) - def on_post_setup(self, event): + def on_post_setup(self, event: Event) -> None: for name, channel in self._event_types.items(): self.add_handlers(name, [h for level in channel.listeners for h in level]) - def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: + def get_emitter( + self, name: str + ) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]: """Get an emitter function for the named event. Parameters @@ -197,7 +204,9 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: pass return channel.emit - def register_listener(self, name: str, listener: Callable, priority: int = 5) -> None: + def register_listener( + self, name: str, listener: Callable[[Event], None], priority: int = 5 + ) -> None: """Registers a new listener to the named event. Parameters @@ -212,7 +221,7 @@ def register_listener(self, name: str, listener: Callable, priority: int = 5) -> """ self.get_channel(name).listeners[priority].append(listener) - def get_listeners(self, name: str) -> Dict[int, List[Callable]]: + def get_listeners(self, name: str) -> dict[int, list[Callable[[Event], None]]]: """Get all listeners registered for the named event. Parameters @@ -232,24 +241,24 @@ def get_listeners(self, name: str) -> Dict[int, List[Callable]]: if listeners } - def list_events(self) -> List[Event]: + def list_events(self) -> list[str]: """List all event names known to the event system. Returns ------- - A list of all known events. + A list of all known events names. Notes ----- This value can change after setup if components dynamically create new event labels. """ - return list(self._event_types.keys()) + return list(self._event_types) - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self._event_types - def __repr__(self): + def __repr__(self) -> str: return "EventManager()" @@ -259,7 +268,9 @@ class EventInterface(Interface): def __init__(self, manager: EventManager): self._manager = manager - def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: + def get_emitter( + self, name: str + ) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]: """Gets an emitter for a named event. Parameters diff --git a/tests/framework/test_event.py b/tests/framework/test_event.py index e8cc8112c..c491de33b 100644 --- a/tests/framework/test_event.py +++ b/tests/framework/test_event.py @@ -122,3 +122,13 @@ def test_contains(): assert event not in manager manager.get_emitter(event) assert event in manager + + +def test_list_events(): + manager = EventManager() + manager.add_constraint = lambda f, **kwargs: f + _ = manager.get_channel("event1") + _ = manager.get_emitter("event2") + _ = manager.register_listener("event3", lambda: None) + _ = manager.get_listeners("event4") + assert manager.list_events() == ["event1", "event2", "event3", "event4"]