From 92e8b3661715c593915720a0b6d66d4f36737837 Mon Sep 17 00:00:00 2001 From: PietroPasotti Date: Thu, 23 Jun 2022 13:39:29 +0200 Subject: [PATCH] Typing framework.py (#772) --- ops/__init__.py | 4 +- ops/charm.py | 7 +- ops/framework.py | 454 +++++++++++++++++++++++++++-------------------- pyproject.toml | 9 +- 4 files changed, 278 insertions(+), 196 deletions(-) diff --git a/ops/__init__.py b/ops/__init__.py index 44cb77a44..5dbba5327 100644 --- a/ops/__init__.py +++ b/ops/__init__.py @@ -40,5 +40,5 @@ """ # Import here the bare minimum to break the circular import between modules -from . import charm # noqa: F401 (imported but unused) -from .version import version as __version__ # noqa: F401 (imported but unused) +from . import charm # type: ignore # noqa +from .version import version as __version__ # type: ignore # noqa diff --git a/ops/charm.py b/ops/charm.py index eb1a5b757..d72ec8130 100755 --- a/ops/charm.py +++ b/ops/charm.py @@ -18,6 +18,7 @@ import os import pathlib import typing +from typing import TYPE_CHECKING from ops import model from ops._private import yaml @@ -669,8 +670,12 @@ def __init__(self, *args): # note that without the #: below, sphinx will copy the whole of CharmEvents # docstring inline which is less than ideal. - #: Used to set up event handlers; see :class:`CharmEvents`. + # Used to set up event handlers; see :class:`CharmEvents`. on = CharmEvents() + if TYPE_CHECKING: + # to help the type checker and IDEs: + @property + def on(self) -> CharmEvents: ... # noqa def __init__(self, framework: Framework, key: typing.Optional = None): super().__init__(framework, None) diff --git a/ops/framework.py b/ops/framework.py index 5cb9ad035..5ddae674c 100755 --- a/ops/framework.py +++ b/ops/framework.py @@ -28,40 +28,80 @@ import types import typing import weakref +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from ops import charm from ops.storage import NoSnapshotError, SQLiteStorage -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from pathlib import Path - from typing_extensions import Protocol, Type + from typing_extensions import Literal, Protocol, Type from ops.charm import CharmMeta - from ops.model import Model + from ops.model import Model, _ModelBackend class _Serializable(Protocol): - handle = None # type: Handle - def snapshot(self) -> dict: ... # noqa: E704 - def restore(self, snapshot: dict) -> "Object": ... # noqa: E704 + handle_kind = '' + @property + def handle(self) -> 'Handle': ... # noqa + @handle.setter + def handle(self, val: 'Handle'): ... # noqa + def snapshot(self) -> Dict[str, '_StorableType']: ... # noqa + def restore(self, snapshot: Dict[str, '_StorableType']) -> None: ... # noqa - _ObjectType = typing.TypeVar("_ObjectType", bound="Object") - _EventType = typing.TypeVar("_EventType", bound=Type["EventBase"]) - _ObserverCallback = typing.Callable[[typing.Any], None] - _Path = _Kind = str + class _StoredObject(Protocol): + _under = None # type: Any # noqa + + # all types that can be (de) serialized to json(/yaml) fom Python builtins + JsonObject = Union[int, float, bool, str, + Dict[str, 'JsonObject'], + List['JsonObject'], + Tuple['JsonObject', ...]] + + # serialized data structure + _SerializedData = Dict[str, 'JsonObject'] + + _ObserverCallback = Callable[[Any], None] + + # types that can be stored natively + _StorableType = Union[int, float, str, bytes, Literal[None], + List['_StorableType'], + Dict[str, '_StorableType'], + Set['_StorableType']] + + StoredObject = Union['StoredList', 'StoredSet', 'StoredDict'] # This type is used to denote either a Handle instance or an instance of # an Object (or subclass). This is used by methods and classes which can be # called with either of those (they need a Handle, but will accept an Object # from which they will then extract the Handle). - _ParentHandle = typing.Union["Handle", _ObjectType] + _ParentHandle = Union['Handle', 'Object'] + _Path = _Kind = _MethodName = _EventKey = str # used to type Framework Attributes - _ObserverPath = typing.List[typing.Tuple['_Path', str, '_Path', str]] - _ObjectPath = typing.Tuple[typing.Optional['_Path'], '_Kind'] - _PathToObserverMapping = typing.Dict[str, '_ObserverCallback'] - _PathToObjectMapping = typing.Dict[str, 'Object'] + _ObserverPath = List[Tuple[_Path, _MethodName, _Path, _EventKey]] + _ObjectPath = Tuple[Optional[_Path], _Kind] + _PathToObjectMapping = Dict[_Path, 'Object'] + _PathToSerializableMapping = Dict[_Path, _Serializable] +_T = TypeVar("_T") +_EventType = TypeVar('_EventType', bound='EventBase') +_ObjectType = TypeVar("_ObjectType", bound="Object") logger = logging.getLogger(__name__) @@ -80,11 +120,11 @@ class Handle: under the same parent and kind may have the same key. """ - def __init__(self, parent: typing.Optional["_ParentHandle"], kind: str, key: str): + def __init__(self, parent: Optional[Union['Handle', 'Object']], kind: str, key: str): if isinstance(parent, Object): # if it's not an Object, it will be either a Handle (good) or None (no parent) parent = parent.handle - self._parent = parent + self._parent = parent # type: Optional[Handle] self._kind = kind self._key = key if parent: @@ -98,31 +138,31 @@ def __init__(self, parent: typing.Optional["_ParentHandle"], kind: str, key: str else: self._path = "{}".format(kind) - def nest(self, kind: str, key: str): + def nest(self, kind: str, key: str) -> 'Handle': """Create a new handle as child of the current one.""" return Handle(self, kind, key) def __hash__(self): return hash((self.parent, self.kind, self.key)) - def __eq__(self, other): + def __eq__(self, other: 'Handle'): return (self.parent, self.kind, self.key) == (other.parent, other.kind, other.key) def __str__(self): return self.path @property - def parent(self): + def parent(self) -> Optional['Handle']: """Return own parent handle.""" return self._parent @property - def kind(self): + def kind(self) -> str: """Return the handle's kind.""" return self._kind @property - def key(self): + def key(self) -> str: """Return the handle's key.""" return self._key @@ -132,7 +172,7 @@ def path(self): return self._path @classmethod - def from_path(cls, path): + def from_path(cls, path: str) -> 'Handle': """Build a handle from the indicated path.""" handle = None for pair in path.split("/"): @@ -148,7 +188,7 @@ def from_path(cls, path): good = True if not good: raise RuntimeError("attempted to restore invalid handle path {}".format(path)) - handle = Handle(handle, kind, key) # noqa + handle = Handle(handle, kind, key) # pyright: reportUnboundVariable=false return handle @@ -224,14 +264,14 @@ def defer(self): logger.debug("Deferring %s.", self) self.deferred = True - def snapshot(self) -> dict: + def snapshot(self) -> '_SerializedData': """Return the snapshot data that should be persisted. Subclasses must override to save any custom state. """ - return None + return {} - def restore(self, snapshot): + def restore(self, snapshot: '_SerializedData'): """Restore the value state from the given snapshot. Subclasses must override to restore their custom state. @@ -239,7 +279,7 @@ def restore(self, snapshot): self.deferred = False -class EventSource: +class EventSource(Generic[_EventType]): """EventSource wraps an event type with a descriptor to facilitate observing and emitting. It is generally used as: @@ -254,20 +294,21 @@ class SomeObject(Object): attribute which is a BoundEvent and may be used to emit and observe the event. """ - def __init__(self, event_type): + def __init__(self, event_type: 'Type[_EventType]'): if not isinstance(event_type, type) or not issubclass(event_type, EventBase): raise RuntimeError( 'Event requires a subclass of EventBase as an argument, got {}'.format(event_type)) - self.event_type = event_type - self.event_kind = None - self.emitter_type = None + self.event_type = event_type # type: Type[_EventType] + self.event_kind = None # type: Optional[str] # noqa + self.emitter_type = None # type: Optional[Type[Object]] # noqa - def _set_name(self, emitter_type, event_kind): + def _set_name(self, emitter_type: 'Type[Object]', event_kind: str): if self.event_kind is not None: raise RuntimeError( 'EventSource({}) reused as {}.{} and {}.{}'.format( self.event_type.__name__, - self.emitter_type.__name__, + # emitter_type could still be None + getattr(self.emitter_type, '__name__', self.emitter_type), self.event_kind, emitter_type.__name__, event_kind, @@ -275,9 +316,11 @@ def _set_name(self, emitter_type, event_kind): self.event_kind = event_kind self.emitter_type = emitter_type - def __get__(self, emitter, emitter_type=None): + def __get__(self, emitter: Optional['Object'], + emitter_type: 'Type[Object]' + ) -> 'BoundEvent[_EventType]': if emitter is None: - return self + return self # type: ignore # Framework might not be available if accessed as CharmClass.on.event # rather than charm_instance.on.event, but in that case it couldn't be # emitted anyway, so there's no point to registering it. @@ -287,7 +330,7 @@ def __get__(self, emitter, emitter_type=None): return BoundEvent(emitter, self.event_type, self.event_kind) -class BoundEvent: +class BoundEvent(Generic[_EventType]): """Event bound to an Object.""" def __repr__(self): @@ -298,22 +341,23 @@ def __repr__(self): hex(id(self)), ) - def __init__(self, emitter: "_ObjectType", - event_type: "_EventType", event_kind: str): + def __init__(self, emitter: 'Object', + event_type: 'Type[EventBase]', + event_kind: str): self.emitter = emitter self.event_type = event_type self.event_kind = event_kind - def emit(self, *args, **kwargs): + def emit(self, *args: Any, **kwargs: Any): """Emit event to all registered observers. The current storage state is committed before and after each observer is notified. """ framework = self.emitter.framework - key = framework._next_event_key() + key = framework._next_event_key() # noqa event = self.event_type(Handle(self.emitter, self.event_kind, key), *args, **kwargs) event.framework = framework - framework._emit(event) + framework._emit(event) # noqa class HandleKind: @@ -323,8 +367,8 @@ class HandleKind: be explicitly overridden if desired. """ - def __get__(self, obj, obj_type): - kind = obj_type.__dict__.get("handle_kind") + def __get__(self, obj: 'Object', obj_type: 'Type[Object]') -> str: + kind = typing.cast(str, obj_type.__dict__.get("handle_kind")) if kind: return kind return obj_type.__name__ @@ -350,15 +394,15 @@ class SomeObject(Object): """ - def __new__(cls, *a, **kw): - k = super().__new__(cls, *a, **kw) + def __new__(cls, *a, **kw): # type: ignore + k = super().__new__(cls, *a, **kw) # type: ignore # k is now the Object-derived class; loop over its class attributes for n, v in vars(k).items(): # we could do duck typing here if we want to support # non-EventSource-derived shenanigans. We don't. if isinstance(v, EventSource): # this is what 3.6+ does automatically for us: - v._set_name(k, n) + v._set_name(k, n) # noqa return k @@ -381,11 +425,18 @@ class Object(metaclass=_Metaclass): been created. """ - framework = None # type: Framework - handle = None # type: Handle handle_kind = HandleKind() # type: str - def __init__(self, parent, key): + if TYPE_CHECKING: + # to help the type checker and IDEs: + # all these are guaranteed to be set at runtime. + @property + def on(self) -> 'ObjectEvents': ... # noqa + + def __init__(self, parent: Union['Framework', 'Object'], key: Optional[str]): + self.framework = None # type: Framework # noqa + self.handle = None # type: Handle # noqa + kind = self.handle_kind if isinstance(parent, Framework): self.framework = parent @@ -396,12 +447,12 @@ def __init__(self, parent, key): else: self.framework = parent.framework self.handle = Handle(parent, kind, key) - self.framework._track(self) + self.framework._track(self) # noqa # TODO Detect conflicting handles here. @property - def model(self) -> "Model": + def model(self) -> 'Model': """Shortcut for more simple access the model.""" return self.framework.model @@ -411,13 +462,12 @@ class ObjectEvents(Object): handle_kind = "on" - def __init__(self, parent=None, key=None): + def __init__(self, parent: Optional[Object] = None, key: Optional[str] = None): if parent is not None: super().__init__(parent, key) - else: - self._cache = weakref.WeakKeyDictionary() + self._cache = weakref.WeakKeyDictionary() # type: weakref.WeakKeyDictionary[Object, 'ObjectEvents'] # noqa - def __get__(self, emitter, emitter_type): + def __get__(self, emitter: Object, emitter_type: 'Type[Object]'): if emitter is None: return self instance = self._cache.get(emitter) @@ -428,7 +478,7 @@ def __get__(self, emitter, emitter_type): return instance @classmethod - def define_event(cls, event_kind, event_type): + def define_event(cls, event_kind: str, event_type: 'Type[EventBase]'): """Define an event on this type at runtime. cls: a type to define an event on. @@ -453,11 +503,11 @@ def define_event(cls, event_kind, event_type): pass event_descriptor = EventSource(event_type) - event_descriptor._set_name(cls, event_kind) + event_descriptor._set_name(cls, event_kind) # noqa setattr(cls, event_kind, event_descriptor) - def _event_kinds(self): - event_kinds = [] + def _event_kinds(self) -> List[str]: + event_kinds = [] # type: List[str] # We have to iterate over the class rather than instance to allow for properties which # might call this method (e.g., event views), leading to infinite recursion. for attr_name, attr_value in inspect.getmembers(type(self)): @@ -467,11 +517,11 @@ def _event_kinds(self): event_kinds.append(attr_name) return event_kinds - def events(self): + def events(self) -> Dict[str, EventSource[EventBase]]: """Return a mapping of event_kinds to bound_events for all available events.""" return {event_kind: getattr(self, event_kind) for event_kind in self._event_kinds()} - def __getitem__(self, key): + def __getitem__(self, key: str) -> 'PrefixedEvents': return PrefixedEvents(self, key) def __repr__(self): @@ -483,11 +533,11 @@ def __repr__(self): class PrefixedEvents: """Events to be found in all events using a specific prefix.""" - def __init__(self, emitter, key): + def __init__(self, emitter: Object, key: str): self._emitter = emitter self._prefix = key.replace("-", "_") + '_' - def __getattr__(self, name): + def __getattr__(self, name: str) -> Union['PrefixedEvents', EventSource[Any]]: return getattr(self._emitter, self._prefix + name) @@ -508,7 +558,7 @@ class FrameworkEvents(ObjectEvents): class NoTypeError(Exception): """No class to hold it was found when restoring an event.""" - def __init__(self, handle_path): + def __init__(self, handle_path: str): self.handle_path = handle_path def __str__(self): @@ -533,16 +583,19 @@ class Framework(Object): on = FrameworkEvents() # Override properties from Object so that we can set them in __init__. - model = None # type: 'Model' - meta = None # type: 'CharmMeta' - charm_dir = None # type: 'Path' + model = None # type: 'Model' # pyright: reportGeneralTypeIssues=false + meta = None # type: 'CharmMeta' # pyright: reportGeneralTypeIssues=false + charm_dir = None # type: 'Path' # pyright: reportGeneralTypeIssues=false - if typing.TYPE_CHECKING: - # to help the type checker and IDEs: + # to help the type checker and IDEs: + + if TYPE_CHECKING: _stored = None # type: 'StoredStateData' + @property + def on(self) -> 'FrameworkEvents': ... # noqa - def __init__(self, storage: SQLiteStorage, charm_dir: "Path", - meta: "CharmMeta", model: "Model"): + def __init__(self, storage: SQLiteStorage, charm_dir: 'Path', + meta: 'CharmMeta', model: 'Model'): super().__init__(self, None) self.charm_dir = charm_dir @@ -550,15 +603,15 @@ def __init__(self, storage: SQLiteStorage, charm_dir: "Path", self.model = model # [(observer_path, method_name, parent_path, event_key)] self._observers = [] # type: _ObserverPath - # {observer_path: observer} - self._observer = weakref.WeakValueDictionary() # type: _PathToObserverMapping + # {observer_path: observing Object} + self._observer = weakref.WeakValueDictionary() # type: _PathToObjectMapping # noqa # {object_path: object} - self._objects = weakref.WeakValueDictionary() # type: _PathToObjectMapping + self._objects = weakref.WeakValueDictionary() # type: _PathToSerializableMapping # noqa # {(parent_path, kind): cls} # (parent_path, kind) is the address of _this_ object: the parent path # plus a 'kind' string that is the name of this object. - self._type_registry = {} # type: typing.Dict[_ObjectPath, 'Type'] - self._type_known = set() # type: typing.Set['Type'] + self._type_registry = {} # type: Dict[_ObjectPath, Type[_Serializable]] + self._type_known = set() # type: Set[Type[_Serializable]] if isinstance(storage, (str, pathlib.Path)): logger.warning( @@ -570,9 +623,9 @@ def __init__(self, storage: SQLiteStorage, charm_dir: "Path", self.register_type(StoredStateData, None, StoredStateData.handle_kind) stored_handle = Handle(None, StoredStateData.handle_kind, '_stored') try: - self._stored = self.load_snapshot(stored_handle) + self._stored = typing.cast(StoredStateData, self.load_snapshot(stored_handle)) # noqa except NoSnapshotError: - self._stored = StoredStateData(self, '_stored') + self._stored = StoredStateData(self, '_stored') # noqa self._stored['event_count'] = 0 # Flag to indicate that we already presented the welcome message in a debugger breakpoint @@ -581,7 +634,7 @@ def __init__(self, storage: SQLiteStorage, charm_dir: "Path", # Parse the env var once, which may be used multiple times later debug_at = os.environ.get('JUJU_DEBUG_AT') self._juju_debug_at = (set(x.strip() for x in debug_at.split(',')) - if debug_at else set()) # type: typing.Set[str] + if debug_at else set()) # type: Set[str] def set_breakpointhook(self): """Hook into sys.breakpointhook so the builtin breakpoint() works as expected. @@ -608,7 +661,7 @@ def close(self): """Close the underlying backends.""" self._storage.close() - def _track(self, obj): + def _track(self, obj: '_Serializable'): """Track object and ensure it is the only object created using its handle path.""" if obj is self: # Framework objects don't track themselves @@ -618,7 +671,7 @@ def _track(self, obj): 'two objects claiming to be {} have been created'.format(obj.handle.path)) self._objects[obj.handle.path] = obj - def _forget(self, obj): + def _forget(self, obj: '_Serializable'): """Stop tracking the given object. See also _track.""" self._objects.pop(obj.handle.path, None) @@ -633,20 +686,20 @@ def commit(self): self.save_snapshot(self._stored) self._storage.commit() - def register_type(self, cls, parent: typing.Optional["_ParentHandle"], kind=None): + def register_type(self, cls: 'Type[_Serializable]', parent: Optional['_ParentHandle'], + kind: str = None): """Register a type to a handle.""" - if parent is not None and not isinstance(parent, Handle): - parent = parent.handle - if parent: + parent_path = None # type: Optional[str] + if isinstance(parent, Object): + parent_path = parent.handle.path + elif isinstance(parent, Handle): parent_path = parent.path - else: - parent_path = None - if not kind: - kind = cls.handle_kind + + kind = kind or cls.handle_kind # type: str self._type_registry[(parent_path, kind)] = cls self._type_known.add(cls) - def save_snapshot(self, value: "_Serializable"): + def save_snapshot(self, value: Union["StoredStateData", "EventBase"]): """Save a persistent snapshot of the provided value.""" if type(value) not in self._type_known: raise RuntimeError( @@ -666,12 +719,12 @@ def save_snapshot(self, value: "_Serializable"): self._storage.save_snapshot(value.handle.path, data) - def load_snapshot(self, handle: Handle) -> '_ObjectType': + def load_snapshot(self, handle: Handle) -> '_Serializable': """Load a persistent snapshot.""" parent_path = None if handle.parent: parent_path = handle.parent.path - cls = self._type_registry.get((parent_path, handle.kind)) + cls = self._type_registry.get((parent_path, handle.kind)) # type: Type[_Serializable] if not cls: raise NoTypeError(handle.path) data = self._storage.load_snapshot(handle.path) @@ -686,8 +739,7 @@ def drop_snapshot(self, handle: Handle): """Discard a persistent snapshot.""" self._storage.drop_snapshot(handle.path) - def observe(self, bound_event: BoundEvent, - observer: "_ObserverCallback"): + def observe(self, bound_event: BoundEvent[Any], observer: "_ObserverCallback"): """Register observer to be called when bound_event is emitted. The bound_event is generally provided as an attribute of the object that emits @@ -738,26 +790,32 @@ class SomeObject: extra_params = list(sig.parameters.values())[1:] method_name = observer.__name__ - observer = observer.__self__ + + assert isinstance(observer.__self__, Object), "can't register observers " \ + "that aren't `Object`s" + observer_obj = observer.__self__ if not sig.parameters: raise TypeError( - '{}.{} must accept event parameter'.format(type(observer).__name__, method_name)) + '{}.{} must accept event parameter'.format( + type(observer_obj).__name__, method_name)) elif any(param.default is inspect.Parameter.empty for param in extra_params): # Allow for additional optional params, since there's no reason to exclude them, but # required params will break. raise TypeError( - '{}.{} has extra required parameter'.format(type(observer).__name__, method_name)) + '{}.{} has extra required parameter'.format( + type(observer_obj).__name__, method_name)) # TODO Prevent the exact same parameters from being registered more than once. - self._observer[observer.handle.path] = observer - self._observers.append((observer.handle.path, method_name, emitter_path, event_kind)) + self._observer[observer_obj.handle.path] = observer_obj + self._observers.append((observer_obj.handle.path, + method_name, emitter_path, event_kind)) - def _next_event_key(self): + def _next_event_key(self) -> str: """Return the next event key that should be used, incrementing the internal counter.""" # Increment the count first; this means the keys will start at 1, and 0 # means no events have been emitted. - self._stored['event_count'] += 1 + self._stored['event_count'] += 1 # type: ignore #(we know event_count holds an int) return str(self._stored['event_count']) def _emit(self, event: EventBase): @@ -765,9 +823,11 @@ def _emit(self, event: EventBase): saved = False event_path = event.handle.path event_kind = event.handle.kind - parent_path = event.handle.parent.path + parent = event.handle.parent + assert isinstance(parent, Handle), "event handle must have a parent" + parent_path = parent.path # TODO Track observers by (parent_path, event_kind) rather than as a list of - # all observers. Avoiding linear search through all observers for every event + # all observers. Avoiding linear search through all observers for every event for observer_path, method_name, _parent_path, _event_kind in self._observers: if _parent_path != parent_path: continue @@ -793,7 +853,7 @@ def reemit(self): """ self._reemit() - def _reemit(self, single_event_path=None): + def _reemit(self, single_event_path: str = None): class EventContext: """Handles toggling the hook-is-running state in backends. @@ -804,18 +864,20 @@ class EventContext: is completed. """ - def __init__(self, framework, event_name): + def __init__(self, framework: Framework, event_name: str): self._event = event_name - self._backend = None + backend = None if framework.model is not None: - self._backend = framework.model._backend + backend = framework.model._backend # noqa + self._backend = backend # type: Optional[_ModelBackend] def __enter__(self): if self._backend: self._backend._hook_is_running = self._event return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type: 'Type[Exception]', + exception: Exception, traceback: Any): if self._backend: self._backend._hook_is_running = '' @@ -836,6 +898,7 @@ def __exit__(self, exception_type, exception, traceback): self._storage.drop_notice(event_path, observer_path, method_name) continue + event = typing.cast(EventBase, event) event.deferred = False observer = self._observer.get(observer_path) if observer: @@ -873,7 +936,7 @@ def _show_debug_code_message(self): self._breakpoint_welcomed = True print(_BREAKPOINT_WELCOME_MESSAGE, file=sys.stderr, end='') - def breakpoint(self, name=None): + def breakpoint(self, name: Optional[str] = None): """Add breakpoint, optionally named, at the place where this method is called. For the breakpoint to be activated the JUJU_DEBUG_AT environment variable @@ -887,7 +950,7 @@ def breakpoint(self, name=None): """ # If given, validate the name comply with all the rules if name is not None: - if not isinstance(name, str): + if not isinstance(name, str): # pyright: reportUnnecessaryIsInstance=false raise TypeError('breakpoint names must be strings') if name in ('hook', 'all'): raise ValueError('breakpoint names "all" and "hook" are reserved') @@ -903,7 +966,7 @@ def breakpoint(self, name=None): # If we call set_trace() directly it will open the debugger *here*, so indicating # it to use our caller's frame - code_frame = inspect.currentframe().f_back + code_frame = inspect.currentframe().f_back # type: ignore pdb.Pdb().set_trace(code_frame) else: logger.warning( @@ -918,7 +981,7 @@ def remove_unreferenced_events(self): database. """ event_regex = re.compile(_event_regex) - to_remove = [] + to_remove = [] # type: List[str] for handle_path in self._storage.list_snapshots(): if event_regex.match(handle_path): notices = self._storage.notices(handle_path) @@ -932,31 +995,31 @@ def remove_unreferenced_events(self): class StoredStateData(Object): """Manager of the stored data.""" - def __init__(self, parent, attr_name): + def __init__(self, parent: Object, attr_name: str): super().__init__(parent, attr_name) - self._cache = {} - self.dirty = False + self._cache = {} # type: Dict[str, '_StorableType'] + self.dirty = False # type: bool - def __getitem__(self, key): + def __getitem__(self, key: str) -> '_StorableType': return self._cache.get(key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: '_StorableType'): self._cache[key] = value self.dirty = True - def __contains__(self, key): + def __contains__(self, key: str): return key in self._cache - def snapshot(self): + def snapshot(self) -> Dict[str, '_StorableType']: """Return the current state.""" return self._cache - def restore(self, snapshot): + def restore(self, snapshot: Dict[str, '_StorableType']): """Restore current state to the given snapshot.""" self._cache = snapshot self.dirty = False - def on_commit(self, event): + def on_commit(self, event: EventBase) -> None: """Save changes to the storage backend.""" if self.dirty: self.framework.save_snapshot(self) @@ -965,8 +1028,16 @@ def on_commit(self, event): class BoundStoredState: """Stored state data bound to a specific Object.""" - - def __init__(self, parent, attr_name): + if TYPE_CHECKING: + # to help the type checker and IDEs: + @property + def _data(self) -> StoredStateData: # noqa + pass # pyright: reportGeneralTypeIssues=false + @property # noqa + def _attr_name(self) -> str: # noqa + pass # pyright: reportGeneralTypeIssues=false + + def __init__(self, parent: Object, attr_name: str): parent.framework.register_type(StoredStateData, parent) handle = Handle(parent, StoredStateData.handle_kind, attr_name) @@ -979,30 +1050,30 @@ def __init__(self, parent, attr_name): self.__dict__["_data"] = data self.__dict__["_attr_name"] = attr_name - parent.framework.observe(parent.framework.on.commit, self._data.on_commit) + parent.framework.observe(parent.framework.on.commit, self._data.on_commit) # type: ignore - def __getattr__(self, key): + def __getattr__(self, key: str) -> Union['_StorableType', 'StoredObject', ObjectEvents]: # "on" is the only reserved key that can't be used in the data map. if key == "on": - return self._data.on + return self._data.on # type: ignore # casting won't work for some reason if key not in self._data: raise AttributeError("attribute '{}' is not stored".format(key)) return _wrap_stored(self._data, self._data[key]) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: '_StoredObject'): if key == "on": raise AttributeError("attribute 'on' is reserved and cannot be set") - value = _unwrap_stored(self._data, value) + unwrapped = _unwrap_stored(self._data, value) - if not isinstance(value, (type(None), int, float, str, bytes, list, dict, set)): + if not isinstance(unwrapped, (type(None), int, float, str, bytes, list, dict, set)): raise AttributeError( 'attribute {!r} cannot be a {}: must be int/float/dict/list/etc'.format( - key, type(value).__name__)) + key, type(unwrapped).__name__)) - self._data[key] = _unwrap_stored(self._data, value) + self._data[key] = unwrapped - def set_default(self, **kwargs): + def set_default(self, **kwargs: Dict[str, '_StorableType']): """Set the value of any given key if it has not already been set.""" for k, v in kwargs.items(): if k not in self._data: @@ -1035,13 +1106,14 @@ def _on_seen(self, event): """ def __init__(self): - self.parent_type = None - self.attr_name = None + self.parent_type = None # type: Optional[Type[Any]] + self.attr_name = None # type: Optional[str] - def __get__(self, parent, parent_type=None): + def __get__(self, parent: '_ObjectType', parent_type: 'Type[_ObjectType]' + ) -> Union['StoredState', BoundStoredState]: if self.parent_type is not None and self.parent_type not in parent_type.mro(): # the StoredState instance is being shared between two unrelated classes - # -> unclear what is exepcted of us -> bail out + # -> unclear what is expected of us -> bail out raise RuntimeError( 'StoredState shared by {} and {}'.format( self.parent_type.__name__, parent_type.__name__)) @@ -1076,6 +1148,10 @@ def __get__(self, parent, parent_type=None): if bound is not None: # cache the bound object to avoid the expensive lookup the next time # (don't use setattr, to keep things symmetric with the fast-path lookup above) + + # attr_name is optional at descriptor level, but we're bound now: it's + # guaranteed to be a string. We need to help the type checker: + assert isinstance(self.attr_name, str) parent.__dict__[self.attr_name] = bound return bound @@ -1084,47 +1160,49 @@ def __get__(self, parent, parent_type=None): self.__class__.__name__, parent_type.__name__)) -def _wrap_stored(parent_data, value): - t = type(value) - if t is dict: +def _wrap_stored(parent_data: StoredStateData, value: '_StorableType' + ) -> Union['StoredDict', 'StoredList', 'StoredSet', '_StorableType']: + if isinstance(value, dict): return StoredDict(parent_data, value) - if t is list: + if isinstance(value, list): return StoredList(parent_data, value) - if t is set: + if isinstance(value, set): return StoredSet(parent_data, value) return value -def _unwrap_stored(parent_data, value): - t = type(value) - if t is StoredDict or t is StoredList or t is StoredSet: - return value._under - return value +def _unwrap_stored(parent_data: StoredStateData, + value: Union['_StoredObject', '_StorableType'] + ) -> '_StorableType': + if isinstance(value, (StoredDict, StoredList, StoredSet)): + return value._under # pyright: reportPrivateUsage=false + return typing.cast('_StorableType', value) -def _wrapped_repr(obj): +def _wrapped_repr(obj: '_StoredObject') -> str: t = type(obj) - if obj._under: - return "{}.{}({!r})".format(t.__module__, t.__name__, obj._under) + if obj._under: # pyright: reportPrivateUsage=false # noqa + return "{}.{}({!r})".format( + t.__module__, t.__name__, obj._under) # type: ignore # noqa else: return "{}.{}()".format(t.__module__, t.__name__) -class StoredDict(collections.abc.MutableMapping): +class StoredDict(typing.MutableMapping[Hashable, '_StorableType']): """A dict-like object that uses the StoredState as backend.""" - def __init__(self, stored_data, under): + def __init__(self, stored_data: StoredStateData, under: Dict[Any, Any]): self._stored_data = stored_data self._under = under - def __getitem__(self, key): + def __getitem__(self, key: Hashable): return _wrap_stored(self._stored_data, self._under[key]) - def __setitem__(self, key, value): + def __setitem__(self, key: Hashable, value: Any): self._under[key] = _unwrap_stored(self._stored_data, value) self._stored_data.dirty = True - def __delitem__(self, key): + def __delitem__(self, key: Hashable): del self._under[key] self._stored_data.dirty = True @@ -1134,7 +1212,7 @@ def __iter__(self): def __len__(self): return len(self._under) - def __eq__(self, other): + def __eq__(self, other: Any): if isinstance(other, StoredDict): return self._under == other._under elif isinstance(other, collections.abc.Mapping): @@ -1142,91 +1220,91 @@ def __eq__(self, other): else: return NotImplemented - __repr__ = _wrapped_repr + __repr__ = _wrapped_repr # type: ignore -class StoredList(collections.abc.MutableSequence): +class StoredList(typing.MutableSequence['_StorableType']): """A list-like object that uses the StoredState as backend.""" - def __init__(self, stored_data, under): + def __init__(self, stored_data: StoredStateData, under: List[Any]): self._stored_data = stored_data self._under = under - def __getitem__(self, index): + def __getitem__(self, index: int): return _wrap_stored(self._stored_data, self._under[index]) - def __setitem__(self, index, value): + def __setitem__(self, index: int, value: Any): self._under[index] = _unwrap_stored(self._stored_data, value) self._stored_data.dirty = True - def __delitem__(self, index): + def __delitem__(self, index: int): del self._under[index] self._stored_data.dirty = True def __len__(self): return len(self._under) - def insert(self, index, value): + def insert(self, index: int, value: Any): """Insert value before index.""" self._under.insert(index, value) self._stored_data.dirty = True - def append(self, value): + def append(self, value: Any): """Append value to the end of the list.""" self._under.append(value) self._stored_data.dirty = True - def __eq__(self, other): + def __eq__(self, other: Any): if isinstance(other, StoredList): return self._under == other._under - elif isinstance(other, collections.abc.Sequence): + elif isinstance(other, list): return self._under == other else: return NotImplemented - def __lt__(self, other): + def __lt__(self, other: Any): if isinstance(other, StoredList): return self._under < other._under - elif isinstance(other, collections.abc.Sequence): + elif isinstance(other, list): return self._under < other else: return NotImplemented - def __le__(self, other): + def __le__(self, other: Any): if isinstance(other, StoredList): return self._under <= other._under - elif isinstance(other, collections.abc.Sequence): + elif isinstance(other, list): return self._under <= other else: return NotImplemented - def __gt__(self, other): + def __gt__(self, other: Any): if isinstance(other, StoredList): return self._under > other._under - elif isinstance(other, collections.abc.Sequence): + elif isinstance(other, list): return self._under > other else: return NotImplemented - def __ge__(self, other): + def __ge__(self, other: Any): if isinstance(other, StoredList): return self._under >= other._under - elif isinstance(other, collections.abc.Sequence): + elif isinstance(other, list): return self._under >= other else: return NotImplemented - __repr__ = _wrapped_repr + __repr__ = _wrapped_repr # type: ignore -class StoredSet(collections.abc.MutableSet): +class StoredSet(typing.MutableSet['_StorableType']): """A set-like object that uses the StoredState as backend.""" - def __init__(self, stored_data, under): + def __init__(self, stored_data: StoredStateData, under: Set[Any]): self._stored_data = stored_data self._under = under - def add(self, key): + def add(self, key: Any): """Add a key to a set. This has no effect if the key is already present. @@ -1234,7 +1312,7 @@ def add(self, key): self._under.add(key) self._stored_data.dirty = True - def discard(self, key): + def discard(self, key: Any): """Remove a key from a set if it is a member. If the key is not a member, do nothing. @@ -1242,7 +1320,7 @@ def discard(self, key): self._under.discard(key) self._stored_data.dirty = True - def __contains__(self, key): + def __contains__(self, key: Any): return key in self._under def __iter__(self): @@ -1252,7 +1330,7 @@ def __len__(self): return len(self._under) @classmethod - def _from_iterable(cls, it): + def _from_iterable(cls, it: Iterable[_T]) -> Set[_T]: """Construct an instance of the class from any iterable input. Per https://docs.python.org/3/library/collections.abc.html @@ -1262,28 +1340,28 @@ def _from_iterable(cls, it): """ return set(it) - def __le__(self, other): + def __le__(self, other: Any): if isinstance(other, StoredSet): return self._under <= other._under - elif isinstance(other, collections.abc.Set): + elif isinstance(other, collections.abc.Set): # pyright: reportUnnecessaryIsInstance=false # noqa return self._under <= other else: return NotImplemented - def __ge__(self, other): + def __ge__(self, other: Any): if isinstance(other, StoredSet): return self._under >= other._under - elif isinstance(other, collections.abc.Set): + elif isinstance(other, collections.abc.Set): # pyright: reportUnnecessaryIsInstance=false # noqa return self._under >= other else: return NotImplemented - def __eq__(self, other): + def __eq__(self, other: Any): if isinstance(other, StoredSet): return self._under == other._under - elif isinstance(other, collections.abc.Set): + elif isinstance(other, collections.abc.Set): # pyright: reportUnnecessaryIsInstance=false # noqa return self._under == other else: return NotImplemented - __repr__ = _wrapped_repr + __repr__ = _wrapped_repr # type: ignore diff --git a/pyproject.toml b/pyproject.toml index 32c623b6b..169cf1c5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,11 +32,10 @@ per-file-ignores = ["test/*:D100,D101,D102,D103,D104"] docstring-convention = "google" [tool.pyright] -exclude = ["ops/_vendor", ".git", "__pycache__", ".tox", "build", "test", - "ops/lib", "ops/_private", "dist", "*.egg_info", "venv", "ops/charm.py", - "ops/framework.py", "ops/main.py", "ops/pebble.py", - "ops/testing.py", "ops/__init__.py"] +include = ["ops/jujuversion.py", "ops/log.py", "ops/model.py", "ops/version.py", + "ops/__init__.py", "ops/framework.py"] pythonVersion = "3.5" # check no python > 3.5 features are used pythonPlatform = "All" typeCheckingMode = "strict" -reportIncompatibleMethodOverride = false \ No newline at end of file +reportIncompatibleMethodOverride = false +reportImportCycles = false \ No newline at end of file