diff --git a/ops/model.py b/ops/model.py index 45e6a48d8..147642a3c 100644 --- a/ops/model.py +++ b/ops/model.py @@ -55,6 +55,7 @@ from ops.jujuversion import JujuVersion if typing.TYPE_CHECKING: + from pebble import _LayerDict # pyright: reportMissingTypeStubs=false from typing_extensions import TypedDict _StorageDictType = Dict[str, Optional[List['Storage']]] @@ -73,8 +74,6 @@ _StatusDict = TypedDict('_StatusDict', {'status': str, 'message': str}) # the data structure we can use to initialize pebble layers with. - # todo: replace with pebble._LayerDict (a TypedDict) when pebble.py is typed - _LayerDict = Dict[str, '_LayerDict'] _Layer = Union[str, _LayerDict, pebble.Layer] # mapping from relation name to a list of relation objects @@ -1327,8 +1326,7 @@ def start(self, *service_names: str): if not service_names: raise TypeError('start expected at least 1 argument, got 0') - # fixme: remove on pebble.exec signature fix - self._pebble.start_services(service_names) # type: ignore + self._pebble.start_services(service_names) def restart(self, *service_names: str): """Restart the given service(s) by name.""" @@ -1336,8 +1334,7 @@ def restart(self, *service_names: str): raise TypeError('restart expected at least 1 argument, got 0') try: - # fixme: remove on pebble.exec signature fix - self._pebble.restart_services(service_names) # type: ignore + self._pebble.restart_services(service_names) except pebble.APIError as e: if e.code != 400: raise e @@ -1345,18 +1342,15 @@ def restart(self, *service_names: str): stop = tuple(s.name for s in self.get_services(*service_names).values( ) if s.is_running()) # type: Tuple[str, ...] if stop: - # fixme: remove on pebble.exec signature fix - self._pebble.stop_services(stop) # type: ignore - # fixme: remove on pebble.exec signature fix - self._pebble.start_services(service_names) # type: ignore + self._pebble.stop_services(stop) + self._pebble.start_services(service_names) def stop(self, *service_names: str): """Stop given service(s) by name.""" if not service_names: raise TypeError('stop expected at least 1 argument, got 0') - # fixme: remove on pebble.exec signature fix - self._pebble.stop_services(service_names) # type: ignore + self._pebble.stop_services(service_names) def add_layer(self, label: str, layer: '_Layer', *, combine: bool = False): """Dynamically add a new layer onto the Pebble configuration layers. @@ -1372,8 +1366,7 @@ def add_layer(self, label: str, layer: '_Layer', *, combine: bool = False): are combined into a single one considering the layer override rules; if the layer doesn't exist, it is added as usual. """ - # fixme: remove ignore once pebble.py is typed - self._pebble.add_layer(label, layer, combine=combine) # type: ignore + self._pebble.add_layer(label, layer, combine=combine) def get_plan(self) -> 'pebble.Plan': """Get the current effective pebble configuration.""" @@ -1386,8 +1379,7 @@ def get_services(self, *service_names: str) -> '_ServiceInfoMapping': services, otherwise return information for only the given services. """ names = service_names or None - # fixme: remove on pebble.exec signature fix - services = self._pebble.get_services(names) # type: ignore + services = self._pebble.get_services(names) return ServiceInfoMapping(services) def get_service(self, service_name: str) -> 'pebble.ServiceInfo': @@ -1414,8 +1406,7 @@ def get_checks( level: Optional check level to query for. If not specified, fetch checks with any level. """ - # fixme: remove on pebble.exec signature fix - checks = self._pebble.get_checks(names=check_names or None, level=level) # type: ignore + checks = self._pebble.get_checks(names=check_names or None, level=level) return CheckInfoMapping(checks) def get_check(self, check_name: str) -> 'pebble.CheckInfo': @@ -1451,7 +1442,7 @@ def push(self, source: Union[bytes, str, BinaryIO, TextIO], *, encoding: str = 'utf-8', - make_dirs: Optional[bool] = False, + make_dirs: bool = False, permissions: Optional[int] = None, user_id: Optional[int] = None, user: Optional[str] = None, @@ -1477,11 +1468,10 @@ def push(self, both are specified. """ self._pebble.push(str(path), source, encoding=encoding, - # fixme: remove these ignores on pebble.exec signature fix - make_dirs=make_dirs, # type: ignore - permissions=permissions, # type: ignore - user_id=user_id, user=user, # type: ignore - group_id=group_id, group=group) # type: ignore + make_dirs=make_dirs, + permissions=permissions, + user_id=user_id, user=user, + group_id=group_id, group=group) def list_files(self, path: StrOrPath, *, pattern: Optional[str] = None, itself: bool = False) -> List['pebble.FileInfo']: @@ -1499,8 +1489,7 @@ def list_files(self, path: StrOrPath, *, pattern: Optional[str] = None, directory itself, rather than its contents. """ return self._pebble.list_files(str(path), - # fixme: remove on pebble.exec signature fix - pattern=pattern, itself=itself) # type: ignore + pattern=pattern, itself=itself) def push_path(self, source_path: Union[StrOrPath, Iterable[StrOrPath]], @@ -1677,7 +1666,7 @@ def _build_fileinfo(path: StrOrPath) -> 'pebble.FileInfo': name=path.name, type=ftype, size=info.st_size, - permissions=stat.S_IMODE(info.st_mode), # type: ignore + permissions=typing.cast(int, stat.S_IMODE(info.st_mode)), # type: ignore last_modified=datetime.datetime.fromtimestamp(info.st_mtime), user_id=info.st_uid, user=pwd.getpwuid(info.st_uid).pw_name, @@ -1770,11 +1759,10 @@ def make_dir( group: Group name for directory. Group's GID must match group_id if both are specified. """ - # fixme: remove ignores on pebble.exec signature fix self._pebble.make_dir(path, make_parents=make_parents, - permissions=permissions, # type: ignore - user_id=user_id, user=user, # type: ignore - group_id=group_id, group=group) # type: ignore + permissions=permissions, + user_id=user_id, user=user, + group_id=group_id, group=group) def remove_path(self, path: str, *, recursive: bool = False): """Remove a file or directory on the remote system. @@ -1809,19 +1797,18 @@ def exec( """ return self._pebble.exec( command, - # fixme: remove ignores on pebble.py typing fix - environment=environment, # type: ignore - working_dir=working_dir, # type: ignore - timeout=timeout, # type: ignore - user_id=user_id, # type: ignore - user=user, # type: ignore - group_id=group_id, # type: ignore - group=group, # type: ignore - stdin=stdin, # type: ignore - stdout=stdout, # type: ignore - stderr=stderr, # type: ignore - encoding=encoding, # type: ignore - combine_stderr=combine_stderr, # type: ignore + environment=environment, + working_dir=working_dir, + timeout=timeout, + user_id=user_id, + user=user, + group_id=group_id, + group=group, + stdin=stdin, + stdout=stdout, + stderr=stderr, + encoding=encoding, + combine_stderr=combine_stderr, ) def send_signal(self, sig: Union[int, str], *service_names: str): @@ -1839,8 +1826,7 @@ def send_signal(self, sig: Union[int, str], *service_names: str): if not service_names: raise TypeError('send_signal expected at least 1 service name, got 0') - # fixme: remove ignore once pebble.send_signature signature is fixed - self._pebble.send_signal(sig, service_names) # type: ignore + self._pebble.send_signal(sig, service_names) class ContainerMapping(Mapping[str, Container]): diff --git a/ops/pebble.py b/ops/pebble.py index f1fb06b4f..e21ef2e76 100644 --- a/ops/pebble.py +++ b/ops/pebble.py @@ -43,22 +43,198 @@ import urllib.parse import urllib.request import warnings +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + TextIO, + Tuple, + Union, +) from ops._private import yaml from ops._vendor import websocket +if TYPE_CHECKING: + from email.message import Message + + from typing_extensions import Literal, Protocol, TypedDict + + # callback types for _MultiParser header and body handlers + class _BodyHandler(Protocol): + def __call__(self, data: bytes, done: bool = False) -> None: ... # noqa + + _HeaderHandler = Callable[[bytes], None] + _StrOrBytes = Union[str, bytes] + + # tempfile.NamedTemporaryFile has an odd interface because of that + # 'name' attribute, so we need to make a Protocol for it. + class _Tempfile(Protocol): + name = '' + def write(self, data: bytes): ... # noqa + def close(self): ... # noqa + + class _FileLikeIO(Protocol[typing.AnyStr]): # That also covers TextIO and BytesIO + def read(self, __n: int = ...) -> typing.AnyStr: ... # for BinaryIO # noqa + def write(self, __s: typing.AnyStr) -> int: ... # noqa + def __enter__(self) -> typing.IO[typing.AnyStr]: ... # noqa + + class _Readable(Protocol): + def read(self, n: int = -1) -> _StrOrBytes: ... # noqa + + class _Writeable(Protocol): + # We'd need something like io.ReadableBuffer here, + # but we can't import that type + def write(self, buf: Union[bytes, str, bytearray]) -> int: ... # noqa + + _AnyStrFileLikeIO = Union[_FileLikeIO[bytes], _FileLikeIO[str]] + _TextOrBinaryIO = Union[TextIO, BinaryIO] + _IOSource = Union[str, bytes, _AnyStrFileLikeIO] + + _SystemInfoDict = TypedDict('_SystemInfoDict', {'version': str}) + _InfoDict = TypedDict('_InfoDict', + {"name": str, + "level": Optional[Union['CheckLevel', str]], + "status": Union['CheckStatus', str], + "failures": int, + "threshold": int}) + _FileInfoDict = TypedDict('_FileInfoDict', + {"path": str, + "name": str, + "size": Optional[int], + "permissions": str, + "last-modified": str, + "user-id": Optional[int], + "user": Optional[str], + "group-id": Optional[int], + "group": Optional[str], + "type": Union['FileType', str]}) + _HttpDict = TypedDict('_HttpDict', {'url': str}) + _TcpDict = TypedDict('_TcpDict', {'port': int}) + _ExecDict = TypedDict('_ExecDict', {'command': str}) + _CheckDict = TypedDict('_CheckDict', + {'override': str, + 'level': Union['CheckLevel', str], + 'period': Optional[str], + 'timeout': Optional[str], + 'http': Optional[_HttpDict], + 'tcp': Optional[_TcpDict], + 'exec': Optional[_ExecDict], + 'threshold': Optional[int]}, + total=False) + + _AuthDict = TypedDict('_AuthDict', + {'permissions': Optional[str], + 'user-id': Optional[int], + 'user': Optional[str], + 'group-id': Optional[int], + 'group': Optional[str], + 'path': Optional[str], + 'make-dirs': Optional[bool], + 'make-parents': Optional[bool], + }, total=False) + _ServiceInfoDict = TypedDict('_ServiceInfoDict', + {'startup': Union['ServiceStartup', str], + 'current': Union['ServiceStatus', str], + 'name': str}) + _ServiceDict = TypedDict('_ServiceDict', + {'summary': str, + 'description': str, + 'startup': str, + 'override': str, + 'command': str, + 'after': Sequence[str], + 'before': Sequence[str], + 'requires': Sequence[str], + 'environment': Dict[str, str], + 'user': str, + 'user-id': Optional[int], + 'group': str, + 'group-id': Optional[int], + 'on-success': str, + 'on-failure': str, + 'on-check-failure': Dict[str, Any], + 'backoff-delay': str, + 'backoff-factor': Optional[int], + 'backoff-limit': str, + }, + total=False) + + _ProgressDict = TypedDict('_ProgressDict', + {'label': str, + 'done': int, + 'total': int}) + _TaskData = Dict[str, Any] + _TaskDict = TypedDict('_TaskDict', + {'id': 'TaskID', + 'kind': str, + 'summary': str, + 'status': str, + 'log': Optional[List[str]], + 'progress': _ProgressDict, + 'spawn-time': str, + 'ready-time': str, + 'data': Optional[_TaskData]}) + _ChangeData = TypedDict('_ChangeData', {}) + _ChangeDict = TypedDict('_ChangeDict', + {'id': str, + 'kind': str, + 'summary': str, + 'status': str, + 'ready': bool, + 'spawn-time': str, + 'tasks': Optional[List[_TaskDict]], + 'err': Optional[str], + 'ready-time': Optional[str], + 'data': Optional[_ChangeData]}) + + _PlanDict = TypedDict('_PlanDict', + {'services': Dict[str, _ServiceDict], + 'checks': Dict[str, _CheckDict]}, + total=False) + + _LayerDict = TypedDict('_LayerDict', + {'summary': str, + 'description': str, + 'services': Dict[str, _ServiceDict], + 'checks': Dict[str, _CheckDict]}, + total=False) + + _Error = TypedDict('_Error', + {'kind': str, + 'message': str}) + _Item = TypedDict('_Item', + {'path': str, + 'error': _Error}) + _FilesResponse = TypedDict('_FilesResponse', + {'result': List[_Item]}) + logger = logging.getLogger(__name__) -_not_provided = object() + +class _NotProvidedFlag: + pass + + +_not_provided = _NotProvidedFlag() class _UnixSocketConnection(http.client.HTTPConnection): """Implementation of HTTPConnection that connects to a named Unix socket.""" - def __init__(self, host, timeout=_not_provided, socket_path=None): + def __init__(self, host: str, socket_path: str, + timeout: Union[_NotProvidedFlag, float] = _not_provided): if timeout is _not_provided: super().__init__(host) else: + assert isinstance(timeout, (int, float)), timeout # type guard for pyright super().__init__(host, timeout=timeout) self.socket_path = socket_path @@ -75,13 +251,14 @@ def connect(self): class _UnixSocketHandler(urllib.request.AbstractHTTPHandler): """Implementation of HTTPHandler that uses a named Unix socket.""" - def __init__(self, socket_path): + def __init__(self, socket_path: str): super().__init__() self.socket_path = socket_path - def http_open(self, req): + def http_open(self, req: urllib.request.Request): """Override http_open to use a Unix socket connection (instead of TCP).""" - return self.do_open(_UnixSocketConnection, req, socket_path=self.socket_path) + return self.do_open(_UnixSocketConnection, req, # type:ignore + socket_path=self.socket_path) # Matches yyyy-mm-ddTHH:MM:SS(.sss)ZZZ @@ -92,7 +269,7 @@ def http_open(self, req): _TIMEOFFSET_RE = re.compile(r'([-+])(\d{2}):(\d{2})') -def _parse_timestamp(s): +def _parse_timestamp(s: str): """Parse timestamp from Go-encoded JSON. This parses RFC3339 timestamps (which are a subset of ISO8601 timestamps) @@ -123,7 +300,7 @@ def _parse_timestamp(s): microsecond=microsecond, tzinfo=tz) -def _format_timeout(timeout: float): +def _format_timeout(timeout: float) -> str: """Format timeout for use in the Pebble API. The format is in seconds with a millisecond resolution and an 's' suffix, @@ -132,7 +309,7 @@ def _format_timeout(timeout: float): return '{:.3f}s'.format(timeout) -def _json_loads(s: typing.Union[str, bytes]) -> typing.Dict: +def _json_loads(s: '_StrOrBytes') -> Dict[Any, Any]: """Like json.loads(), but handle str or bytes. This is needed because an HTTP response's read() method returns bytes on @@ -143,7 +320,7 @@ def _json_loads(s: typing.Union[str, bytes]) -> typing.Dict: return json.loads(s) -def _start_thread(target, *args, **kwargs) -> threading.Thread: +def _start_thread(target: Callable[..., Any], *args: Any, **kwargs: Any) -> threading.Thread: """Helper to simplify starting a thread.""" thread = threading.Thread(target=target, args=args, kwargs=kwargs) thread.start() @@ -183,7 +360,9 @@ class PathError(Error): def __init__(self, kind: str, message: str): """This shouldn't be instantiated directly.""" self.kind = kind - self.message = message + # FIXME: pyright rightfully complains that super().message is a method + # see: https://github.com/canonical/operator/issues/777 + self.message = message # type: ignore def __str__(self): return '{} - {}'.format(self.kind, self.message) @@ -195,13 +374,14 @@ def __repr__(self): class APIError(Error): """Raised when an HTTP API error occurs talking to the Pebble server.""" - def __init__(self, body: typing.Dict, code: int, status: str, message: str): + def __init__(self, body: Dict[str, Any], code: int, status: str, message: str): """This shouldn't be instantiated directly.""" super().__init__(message) # Makes str(e) return message self.body = body self.code = code self.status = status - self.message = message + # FIXME: pyright rightfully complains that super().message is a method + self.message = message # type: ignore def __repr__(self): return 'APIError({!r}, {!r}, {!r}, {!r})'.format( @@ -261,10 +441,10 @@ class ExecError(Error): def __init__( self, - command: typing.List[str], + command: List[str], exit_code: int, - stdout: typing.Optional[typing.AnyStr], - stderr: typing.Optional[typing.AnyStr], + stdout: Optional['_StrOrBytes'], + stderr: Optional['_StrOrBytes'], ): self.command = command self.exit_code = exit_code @@ -307,7 +487,7 @@ def __init__(self, version: str): self.version = version @classmethod - def from_dict(cls, d: typing.Dict) -> 'SystemInfo': + def from_dict(cls, d: '_SystemInfoDict') -> 'SystemInfo': """Create new SystemInfo object from dict parsed from JSON.""" return cls(version=d['version']) @@ -317,13 +497,21 @@ def __repr__(self): class Warning: """Warning object.""" + if typing.TYPE_CHECKING: + _WarningDict = TypedDict('_WarningDict', + {'message': str, + 'first-added': str, + 'last-added': str, + 'last-shown': Optional[str], + 'expire-after': str, + 'repeat-after': str}) def __init__( self, message: str, first_added: datetime.datetime, last_added: datetime.datetime, - last_shown: typing.Optional[datetime.datetime], + last_shown: Optional[datetime.datetime], expire_after: str, repeat_after: str, ): @@ -335,13 +523,14 @@ def __init__( self.repeat_after = repeat_after @classmethod - def from_dict(cls, d: typing.Dict) -> 'Warning': + def from_dict(cls, d: '_WarningDict') -> 'Warning': """Create new Warning object from dict parsed from JSON.""" return cls( message=d['message'], first_added=_parse_timestamp(d['first-added']), last_added=_parse_timestamp(d['last-added']), - last_shown=_parse_timestamp(d['last-shown']) if d.get('last-shown') else None, + last_shown=(_parse_timestamp(d['last-shown']) # type: ignore + if d.get('last-shown') else None), expire_after=d['expire-after'], repeat_after=d['repeat-after'], ) @@ -371,7 +560,7 @@ def __init__( self.total = total @classmethod - def from_dict(cls, d: typing.Dict) -> 'TaskProgress': + def from_dict(cls, d: '_ProgressDict') -> 'TaskProgress': """Create new TaskProgress object from dict parsed from JSON.""" return cls( label=d['label'], @@ -403,11 +592,11 @@ def __init__( kind: str, summary: str, status: str, - log: typing.List[str], + log: List[str], progress: TaskProgress, spawn_time: datetime.datetime, - ready_time: typing.Optional[datetime.datetime], - data: typing.Dict[str, typing.Any] = None, + ready_time: Optional[datetime.datetime], + data: Optional['_TaskData'] = None, ): self.id = id self.kind = kind @@ -420,7 +609,7 @@ def __init__( self.data = data or {} @classmethod - def from_dict(cls, d: typing.Dict) -> 'Task': + def from_dict(cls, d: '_TaskDict') -> 'Task': """Create new Task object from dict parsed from JSON.""" return cls( id=TaskID(d['id']), @@ -464,12 +653,12 @@ def __init__( kind: str, summary: str, status: str, - tasks: typing.List[Task], + tasks: List[Task], ready: bool, - err: typing.Optional[str], + err: Optional[str], spawn_time: datetime.datetime, - ready_time: typing.Optional[datetime.datetime], - data: typing.Dict[str, typing.Any] = None, + ready_time: Optional[datetime.datetime], + data: Optional['_ChangeData'] = None, ): self.id = id self.kind = kind @@ -483,7 +672,7 @@ def __init__( self.data = data or {} @classmethod - def from_dict(cls, d: typing.Dict) -> 'Change': + def from_dict(cls, d: '_ChangeDict') -> 'Change': """Create new Change object from dict parsed from JSON.""" return cls( id=ChangeID(d['id']), @@ -494,7 +683,8 @@ def from_dict(cls, d: typing.Dict) -> 'Change': ready=d['ready'], err=d.get('err'), spawn_time=_parse_timestamp(d['spawn-time']), - ready_time=_parse_timestamp(d['ready-time']) if d.get('ready-time') else None, + ready_time=(_parse_timestamp(d['ready-time']) # type: ignore + if d.get('ready-time') else None), data=d.get('data') or {}, ) @@ -521,15 +711,19 @@ class Plan: """ def __init__(self, raw: str): - d = yaml.safe_load(raw) or {} + d = yaml.safe_load(raw) or {} # type: ignore + d = typing.cast('_PlanDict', d) + self._raw = raw self._services = {name: Service(name, service) - for name, service in d.get('services', {}).items()} + for name, service in d.get('services', {}).items() + } # type: Dict[str, Service] self._checks = {name: Check(name, check) - for name, check in d.get('checks', {}).items()} + for name, check in d.get('checks', {}).items() + } # type: Dict[str, Check] @property - def services(self): + def services(self) -> Dict[str, 'Service']: """This plan's services mapping (maps service name to Service). This property is currently read-only. @@ -537,24 +731,25 @@ def services(self): return self._services @property - def checks(self): + def checks(self) -> Dict[str, 'Check']: """This plan's checks mapping (maps check name to :class:`Check`). This property is currently read-only. """ return self._checks - def to_dict(self) -> typing.Dict[str, typing.Any]: + def to_dict(self) -> '_PlanDict': """Convert this plan to its dict representation.""" fields = [ ('services', {name: service.to_dict() for name, service in self._services.items()}), ('checks', {name: check.to_dict() for name, check in self._checks.items()}), ] - return {name: value for name, value in fields if value} + dct = {name: value for name, value in fields if value} + return typing.cast('_PlanDict', dct) def to_yaml(self) -> str: """Return this plan's YAML representation.""" - return yaml.safe_dump(self.to_dict()) + return yaml.safe_dump(self.to_dict()) # type: ignore __str__ = to_yaml @@ -575,25 +770,30 @@ class Layer: # This is how you do type annotations, but it is not supported by Python 3.5 # summary: str # description: str - # services: typing.Mapping[str, 'Service'] + # services: Mapping[str, 'Service'] - def __init__(self, raw: typing.Union[str, typing.Dict] = None): + def __init__(self, raw: Optional[Union[str, '_LayerDict']] = None): if isinstance(raw, str): - d = yaml.safe_load(raw) or {} + d = yaml.safe_load(raw) or {} # type: ignore # (Any 'raw' type) else: d = raw or {} - self.summary = d.get('summary', '') - self.description = d.get('description', '') + d = typing.cast('_LayerDict', d) + + self.summary = d.get('summary', '') # type: str + self.description = d.get('description', '') # type: str self.services = {name: Service(name, service) - for name, service in d.get('services', {}).items()} + for name, service in d.get('services', {}).items() + } # type: Dict[str, Service] self.checks = {name: Check(name, check) - for name, check in d.get('checks', {}).items()} + for name, check in d.get('checks', {}).items() + } # type: Dict[str, Check] def to_yaml(self) -> str: """Convert this layer to its YAML representation.""" - return yaml.safe_dump(self.to_dict()) + yamlstr = yaml.safe_dump(self.to_dict()) # type: ignore + return typing.cast(str, yamlstr) - def to_dict(self) -> typing.Dict[str, typing.Any]: + def to_dict(self) -> '_LayerDict': """Convert this layer to its dict representation.""" fields = [ ('summary', self.summary), @@ -601,7 +801,8 @@ def to_dict(self) -> typing.Dict[str, typing.Any]: ('services', {name: service.to_dict() for name, service in self.services.items()}), ('checks', {name: check.to_dict() for name, check in self.checks.items()}), ] - return {name: value for name, value in fields if value} + dct = {name: value for name, value in fields if value} + return typing.cast('_LayerDict', dct) def __repr__(self) -> str: return 'Layer({!r})'.format(self.to_dict()) @@ -612,30 +813,30 @@ def __repr__(self) -> str: class Service: """Represents a service description in a Pebble configuration layer.""" - def __init__(self, name: str, raw: typing.Dict = None): + def __init__(self, name: str, raw: Optional['_ServiceDict'] = None): self.name = name - raw = raw or {} - self.summary = raw.get('summary', '') - self.description = raw.get('description', '') - self.startup = raw.get('startup', '') - self.override = raw.get('override', '') - self.command = raw.get('command', '') - self.after = list(raw.get('after', [])) - self.before = list(raw.get('before', [])) - self.requires = list(raw.get('requires', [])) - self.environment = dict(raw.get('environment', {})) - self.user = raw.get('user', '') - self.user_id = raw.get('user-id') - self.group = raw.get('group', '') - self.group_id = raw.get('group-id') - self.on_success = raw.get('on-success', '') - self.on_failure = raw.get('on-failure', '') - self.on_check_failure = dict(raw.get('on-check-failure', {})) - self.backoff_delay = raw.get('backoff-delay', '') - self.backoff_factor = raw.get('backoff-factor') - self.backoff_limit = raw.get('backoff-limit', '') - - def to_dict(self) -> typing.Dict[str, typing.Any]: + dct = raw or {} # type: _ServiceDict + self.summary = dct.get('summary', '') + self.description = dct.get('description', '') + self.startup = dct.get('startup', '') + self.override = dct.get('override', '') + self.command = dct.get('command', '') + self.after = list(dct.get('after', [])) + self.before = list(dct.get('before', [])) + self.requires = list(dct.get('requires', [])) + self.environment = dict(dct.get('environment', {})) + self.user = dct.get('user', '') + self.user_id = dct.get('user-id') + self.group = dct.get('group', '') + self.group_id = dct.get('group-id') + self.on_success = dct.get('on-success', '') + self.on_failure = dct.get('on-failure', '') + self.on_check_failure = dict(dct.get('on-check-failure', {})) + self.backoff_delay = dct.get('backoff-delay', '') + self.backoff_factor = dct.get('backoff-factor') + self.backoff_limit = dct.get('backoff-limit', '') + + def to_dict(self) -> '_ServiceDict': """Convert this service object to its dict representation.""" fields = [ ('summary', self.summary), @@ -658,7 +859,8 @@ def to_dict(self) -> typing.Dict[str, typing.Any]: ('backoff-factor', self.backoff_factor), ('backoff-limit', self.backoff_limit), ] - return {name: value for name, value in fields if value} + dct = {name: value for name, value in fields if value} + return typing.cast('_ServiceDict', dct) def _merge(self, other: 'Service'): """Merges this service object with another service definition. @@ -679,7 +881,7 @@ def _merge(self, other: 'Service'): def __repr__(self) -> str: return 'Service({!r})'.format(self.to_dict()) - def __eq__(self, other: typing.Union[typing.Dict, 'Service']) -> bool: + def __eq__(self, other: Union['_ServiceDict', 'Service']) -> bool: """Compare this service description to another.""" if isinstance(other, dict): return self.to_dict() == other @@ -712,8 +914,8 @@ class ServiceInfo: def __init__( self, name: str, - startup: typing.Union[ServiceStartup, str], - current: typing.Union[ServiceStatus, str], + startup: Union[ServiceStartup, str], + current: Union[ServiceStatus, str], ): self.name = name self.startup = startup @@ -724,7 +926,7 @@ def is_running(self) -> bool: return self.current == ServiceStatus.ACTIVE @classmethod - def from_dict(cls, d: typing.Dict) -> 'ServiceInfo': + def from_dict(cls, d: '_ServiceInfoDict') -> 'ServiceInfo': """Create new ServiceInfo object from dict parsed from JSON.""" try: startup = ServiceStartup(d['startup']) @@ -751,38 +953,40 @@ def __repr__(self): class Check: """Represents a check in a Pebble configuration layer.""" - def __init__(self, name: str, raw: typing.Dict = None): + def __init__(self, name: str, raw: Optional['_CheckDict'] = None): self.name = name - raw = raw or {} - self.override = raw.get('override', '') + dct = raw or {} # type: _CheckDict + self.override = dct.get('override', '') # type: str try: - self.level = CheckLevel(raw.get('level', '')) + level = CheckLevel(dct.get('level', '')) # type: Union[CheckLevel, str] except ValueError: - self.level = raw.get('level') - self.period = raw.get('period', '') - self.timeout = raw.get('timeout', '') - self.threshold = raw.get('threshold') + level = dct.get('level', '') + self.level = level + self.period = dct.get('period', '') # type: Optional[str] + self.timeout = dct.get('timeout', '') # type: Optional[str] + self.threshold = dct.get('threshold') # type: Optional[int] - http = raw.get('http') + http = dct.get('http') if http is not None: http = copy.deepcopy(http) - self.http = http + self.http = http # type: Optional[_HttpDict] - tcp = raw.get('tcp') + tcp = dct.get('tcp') if tcp is not None: tcp = copy.deepcopy(tcp) - self.tcp = tcp + self.tcp = tcp # type: Optional[_TcpDict] - exec_ = raw.get('exec') + exec_ = dct.get('exec') if exec_ is not None: exec_ = copy.deepcopy(exec_) - self.exec = exec_ + self.exec = exec_ # type: Optional[_ExecDict] - def to_dict(self) -> typing.Dict[str, typing.Any]: + def to_dict(self) -> '_CheckDict': """Convert this check object to its dict representation.""" + level = self.level.value if isinstance(self.level, CheckLevel) else self.level # type: str fields = [ ('override', self.override), - ('level', self.level.value), + ('level', level), ('period', self.period), ('timeout', self.timeout), ('threshold', self.threshold), @@ -790,14 +994,15 @@ def to_dict(self) -> typing.Dict[str, typing.Any]: ('tcp', self.tcp), ('exec', self.exec), ] - return {name: value for name, value in fields if value} + dct = {name: value for name, value in fields if value} + return typing.cast('_CheckDict', dct) def __repr__(self) -> str: return 'Check({!r})'.format(self.to_dict()) - def __eq__(self, other: typing.Union[typing.Dict, 'Check']) -> bool: + def __eq__(self, other: Union['_CheckDict', 'Check']) -> bool: """Compare this check configuration to another.""" - if isinstance(other, dict): + if isinstance(other, dict): # pyright: reportUnnecessaryComparison=false return self.to_dict() == other elif isinstance(other, Check): return self.to_dict() == other.to_dict() @@ -839,14 +1044,14 @@ def __init__( self, path: str, name: str, - type: typing.Union['FileType', str], - size: typing.Optional[int], + type: Union['FileType', str], + size: Optional[int], permissions: int, last_modified: datetime.datetime, - user_id: typing.Optional[int], - user: typing.Optional[str], - group_id: typing.Optional[int], - group: typing.Optional[str], + user_id: Optional[int], + user: Optional[str], + group_id: Optional[int], + group: Optional[str], ): self.path = path self.name = name @@ -860,7 +1065,7 @@ def __init__( self.group = group @classmethod - def from_dict(cls, d: typing.Dict) -> 'FileInfo': + def from_dict(cls, d: '_FileInfoDict') -> 'FileInfo': """Create new FileInfo object from dict parsed from JSON.""" try: file_type = FileType(d['type']) @@ -916,8 +1121,8 @@ class CheckInfo: def __init__( self, name: str, - level: typing.Optional[typing.Union[CheckLevel, str]], - status: typing.Union[CheckStatus, str], + level: Optional[Union[CheckLevel, str]], + status: Union[CheckStatus, str], failures: int = 0, threshold: int = 0, ): @@ -928,7 +1133,7 @@ def __init__( self.threshold = threshold @classmethod - def from_dict(cls, d: typing.Dict) -> 'CheckInfo': + def from_dict(cls, d: '_InfoDict') -> 'CheckInfo': """Create new :class:`CheckInfo` object from dict parsed from JSON.""" try: level = CheckLevel(d.get('level', '')) @@ -985,20 +1190,20 @@ class ExecProcess: def __init__( self, - stdin: typing.Optional[typing.Union[typing.TextIO, typing.BinaryIO]], - stdout: typing.Optional[typing.Union[typing.TextIO, typing.BinaryIO]], - stderr: typing.Optional[typing.Union[typing.TextIO, typing.BinaryIO]], + stdin: Optional['_Readable'], + stdout: Optional['_Writeable'], + stderr: Optional['_Writeable'], client: 'Client', - timeout: typing.Optional[float], + timeout: Optional[float], control_ws: websocket.WebSocket, stdio_ws: websocket.WebSocket, - stderr_ws: websocket.WebSocket, - command: typing.List[str], - encoding: typing.Optional[str], + stderr_ws: Optional[websocket.WebSocket], + command: List[str], + encoding: Optional[str], change_id: ChangeID, - cancel_stdin: typing.Callable[[], None], - cancel_reader: typing.Optional[int], - threads: typing.List[threading.Thread], + cancel_stdin: Optional[Callable[[], None]], + cancel_reader: Optional[int], + threads: List[threading.Thread], ): self.stdin = stdin self.stdout = stdout @@ -1035,7 +1240,7 @@ def wait(self): if exit_code != 0: raise ExecError(self._command, exit_code, None, None) - def _wait(self): + def _wait(self) -> int: self._waited = True timeout = self._timeout if timeout is not None: @@ -1070,7 +1275,7 @@ def _wait(self): exit_code = change.tasks[0].data.get('exit-code', -1) return exit_code - def wait_output(self) -> typing.Tuple[typing.AnyStr, typing.AnyStr]: + def wait_output(self) -> Tuple['_StrOrBytes', Optional['_StrOrBytes']]: """Wait for the process to finish and return tuple of (stdout, stderr). If a timeout was specified to the :meth:`Client.exec` call, this waits @@ -1095,16 +1300,16 @@ def wait_output(self) -> typing.Tuple[typing.AnyStr, typing.AnyStr]: t = _start_thread(shutil.copyfileobj, self.stderr, err) self._threads.append(t) - exit_code = self._wait() + exit_code = self._wait() # type: int - out_value = out.getvalue() - err_value = err.getvalue() if err is not None else None + out_value = out.getvalue() # type: '_StrOrBytes' + err_value = err.getvalue() if err is not None else None # type: Optional['_StrOrBytes'] if exit_code != 0: raise ExecError(self._command, exit_code, out_value, err_value) return (out_value, err_value) - def send_signal(self, sig: typing.Union[int, str]): + def send_signal(self, sig: Union[int, str]): """Send the given signal to the running process. Args: @@ -1118,13 +1323,13 @@ def send_signal(self, sig: typing.Union[int, str]): 'signal': {'name': sig}, } msg = json.dumps(payload, sort_keys=True) - self._control_ws.send(msg) + self._control_ws.send(msg) # type: ignore -def _has_fileno(f): +def _has_fileno(f: Any) -> bool: """Return True if the file-like object has a valid fileno() method.""" try: - f.fileno() + f.fileno() # type: ignore # noqa return True except Exception: # Some types define a fileno method that raises io.UnsupportedOperation, @@ -1132,7 +1337,11 @@ def _has_fileno(f): return False -def _reader_to_websocket(reader, ws, encoding, cancel_reader=None, bufsize=16 * 1024): +def _reader_to_websocket(reader: '_WebsocketReader', + ws: websocket.WebSocket, + encoding: str, + cancel_reader: Optional[int] = None, + bufsize: int = 16 * 1024): """Read reader through to EOF and send each chunk read to the websocket.""" while True: if cancel_reader is not None: @@ -1146,15 +1355,16 @@ def _reader_to_websocket(reader, ws, encoding, cancel_reader=None, bufsize=16 * break if isinstance(chunk, str): chunk = chunk.encode(encoding) - ws.send_binary(chunk) + ws.send_binary(chunk) # type: ignore - ws.send('{"command":"end"}') # Send "end" command as TEXT frame to signal EOF + ws.send('{"command":"end"}') # type: ignore # Send "end" command as TEXT frame to signal EOF -def _websocket_to_writer(ws, writer, encoding): +def _websocket_to_writer(ws: websocket.WebSocket, writer: '_WebsocketWriter', + encoding: str): """Receive messages from websocket (until end signal) and write to writer.""" while True: - chunk = ws.recv() + chunk = ws.recv() # type: '_StrOrBytes' if isinstance(chunk, str): try: @@ -1179,45 +1389,45 @@ def _websocket_to_writer(ws, writer, encoding): class _WebsocketWriter(io.BufferedIOBase): """A writable file-like object that sends what's written to it to a websocket.""" - def __init__(self, ws): + def __init__(self, ws: websocket.WebSocket): self.ws = ws def writable(self): """Denote this file-like object as writable.""" return True - def write(self, chunk): + def write(self, chunk: '_StrOrBytes') -> int: """Write chunk to the websocket.""" if not isinstance(chunk, bytes): raise TypeError('value to write must be bytes, not {}'.format(type(chunk).__name__)) - self.ws.send_binary(chunk) + self.ws.send_binary(chunk) # type: ignore return len(chunk) def close(self): """Send end-of-file message to websocket.""" - self.ws.send('{"command":"end"}') + self.ws.send('{"command":"end"}') # type: ignore class _WebsocketReader(io.BufferedIOBase): """A readable file-like object whose reads come from a websocket.""" - def __init__(self, ws): + def __init__(self, ws: websocket.WebSocket): self.ws = ws self.remaining = b'' self.eof = False - def readable(self): + def readable(self) -> bool: """Denote this file-like object as readable.""" return True - def read(self, n=-1): + def read(self, n: int = -1) -> '_StrOrBytes': """Read up to n bytes from the websocket (or one message if n<0).""" if self.eof: # Calling read() multiple times after EOF should still return EOF return b'' while not self.remaining: - chunk = self.ws.recv() + chunk = self.ws.recv() # type: '_StrOrBytes' if isinstance(chunk, str): try: @@ -1239,11 +1449,11 @@ def read(self, n=-1): if n < 0: n = len(self.remaining) - result = self.remaining[:n] + result = self.remaining[:n] # type: '_StrOrBytes' self.remaining = self.remaining[n:] return result - def read1(self, n=-1): + def read1(self, n: int = -1) -> '_StrOrBytes': """An alias for read.""" return self.read(n) @@ -1253,15 +1463,19 @@ class Client: _chunk_size = 8192 - def __init__(self, socket_path=None, opener=None, base_url='http://localhost', timeout=5.0): + def __init__(self, socket_path: str, + opener: Optional[urllib.request.OpenerDirector] = None, + base_url: str = 'http://localhost', + timeout: float = 5.0): """Initialize a client instance. Defaults to using a Unix socket at socket_path (which must be specified unless a custom opener is provided). """ + if not isinstance(socket_path, str): + raise TypeError('`socket_path` should be a string, ' + 'not: {}'.format(type(socket_path))) if opener is None: - if socket_path is None: - raise ValueError('no socket path provided') opener = self._get_default_opener(socket_path) self.socket_path = socket_path self.opener = opener @@ -1269,7 +1483,7 @@ def __init__(self, socket_path=None, opener=None, base_url='http://localhost', t self.timeout = timeout @classmethod - def _get_default_opener(cls, socket_path): + def _get_default_opener(cls, socket_path: str) -> urllib.request.OpenerDirector: """Build the default opener to use for requests (HTTP over Unix socket).""" opener = urllib.request.OpenerDirector() opener.add_handler(_UnixSocketHandler(socket_path)) @@ -1278,9 +1492,13 @@ def _get_default_opener(cls, socket_path): opener.add_handler(urllib.request.HTTPErrorProcessor()) return opener - def _request( - self, method: str, path: str, query: typing.Dict = None, body: typing.Dict = None, - ) -> typing.Dict: + # we need to cast the return type depending on the request params + def _request(self, + method: str, + path: str, + query: Optional[Dict[str, Any]] = None, + body: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Make a JSON request to the Pebble server with the given HTTP method and path. If query dict is provided, it is encoded and appended as a query string @@ -1296,10 +1514,12 @@ def _request( response = self._request_raw(method, path, query, headers, data) self._ensure_content_type(response.headers, 'application/json') - return _json_loads(response.read()) + raw_resp = _json_loads(response.read()) # type: Dict[str, Any] + return raw_resp @staticmethod - def _ensure_content_type(headers, expected): + def _ensure_content_type(headers: 'Message', + expected: 'Literal["multipart/form-data", "application/json"]'): """Parse Content-Type header from headers and ensure it's equal to expected. Return a dict of any options in the header, e.g., {'boundary': ...}. @@ -1310,8 +1530,10 @@ def _ensure_content_type(headers, expected): return options def _request_raw( - self, method: str, path: str, query: typing.Dict = None, headers: typing.Dict = None, - data: bytes = None, + self, method: str, path: str, + query: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + data: Optional[Union[bytes, Generator[bytes, Any, Any]]] = None, ) -> http.client.HTTPResponse: """Make a request to the Pebble server; return the raw HTTPResponse object.""" url = self.base_url + path @@ -1333,11 +1555,11 @@ def _request_raw( code = e.code status = e.reason try: - body = _json_loads(e.read()) - message = body['result']['message'] + body = _json_loads(e.read()) # type: Dict[str, Any] + message = body['result']['message'] # type: str except (IOError, ValueError, KeyError) as e2: # Will only happen on read error or if Pebble sends invalid JSON. - body = {} + body = {} # type: Dict[str, Any] message = '{} - {}'.format(type(e2).__name__, e2) raise APIError(body, code, status, message) except urllib.error.URLError as e: @@ -1350,7 +1572,7 @@ def get_system_info(self) -> SystemInfo: resp = self._request('GET', '/v1/system-info') return SystemInfo.from_dict(resp['result']) - def get_warnings(self, select: WarningState = WarningState.PENDING) -> typing.List[Warning]: + def get_warnings(self, select: WarningState = WarningState.PENDING) -> List[Warning]: """Get list of warnings in given state (pending or all).""" query = {'select': select.value} resp = self._request('GET', '/v1/warnings', query) @@ -1363,10 +1585,10 @@ def ack_warnings(self, timestamp: datetime.datetime) -> int: return resp['result'] def get_changes( - self, select: ChangeState = ChangeState.IN_PROGRESS, service: str = None, - ) -> typing.List[Change]: + self, select: ChangeState = ChangeState.IN_PROGRESS, service: Optional[str] = None, + ) -> List[Change]: """Get list of changes in given state, filter by service name if given.""" - query = {'select': select.value} + query = {'select': select.value} # type: Dict[str, Union[str, int]] if service is not None: query['for'] = service resp = self._request('GET', '/v1/changes', query) @@ -1418,7 +1640,7 @@ def replan_services(self, timeout: float = 30.0, delay: float = 0.1) -> ChangeID return self._services_action('replan', [], timeout, delay) def start_services( - self, services: typing.List[str], timeout: float = 30.0, delay: float = 0.1, + self, services: Iterable[str], timeout: float = 30.0, delay: float = 0.1, ) -> ChangeID: """Start services by name and wait (poll) for them to be started. @@ -1438,7 +1660,7 @@ def start_services( return self._services_action('start', services, timeout, delay) def stop_services( - self, services: typing.List[str], timeout: float = 30.0, delay: float = 0.1, + self, services: Iterable[str], timeout: float = 30.0, delay: float = 0.1, ) -> ChangeID: """Stop services by name and wait (poll) for them to be started. @@ -1458,7 +1680,7 @@ def stop_services( return self._services_action('stop', services, timeout, delay) def restart_services( - self, services: typing.List[str], timeout: float = 30.0, delay: float = 0.1, + self, services: Iterable[str], timeout: float = 30.0, delay: float = 0.1, ) -> ChangeID: """Restart services by name and wait (poll) for them to be started. @@ -1478,11 +1700,15 @@ def restart_services( return self._services_action('restart', services, timeout, delay) def _services_action( - self, action: str, services: typing.Iterable[str], timeout: float, delay: float, + self, action: str, services: Iterable[str], timeout: Optional[float], + delay: float, ) -> ChangeID: - if not isinstance(services, (list, tuple)): - raise TypeError('services must be a list of str, not {}'.format( - type(services).__name__)) + if isinstance(services, (str, bytes)) or not hasattr(services, '__iter__'): + raise TypeError( + 'services must be of type Iterable[str], not {}'.format( + type(services).__name__)) + + services = list(services) for s in services: if not isinstance(s, str): raise TypeError('service names must be str, not {}'.format(type(s).__name__)) @@ -1497,7 +1723,9 @@ def _services_action( return change_id def wait_change( - self, change_id: ChangeID, timeout: float = 30.0, delay: float = 0.1, + self, change_id: ChangeID, + timeout: Optional[float] = 30.0, + delay: float = 0.1, ) -> Change: """Wait for the given change to be ready. @@ -1508,7 +1736,7 @@ def wait_change( Args: change_id: Change ID of change to wait for. timeout: Maximum time in seconds to wait for the change to be - ready. May be None, in which case wait_change never times out. + ready. It may be None, in which case wait_change never times out. delay: If polling, this is the delay in seconds between attempts. Returns: @@ -1523,9 +1751,9 @@ def wait_change( # Pebble server doesn't support wait endpoint, fall back to polling return self._wait_change_using_polling(change_id, timeout, delay) - def _wait_change_using_wait(self, change_id, timeout): + def _wait_change_using_wait(self, change_id: ChangeID, timeout: Optional[float]): """Wait for a change to be ready using the wait-change API.""" - deadline = time.time() + timeout if timeout is not None else None + deadline = time.time() + timeout if timeout is not None else 0 # Hit the wait endpoint every Client.timeout-1 seconds to avoid long # requests (the -1 is to ensure it wakes up before the socket timeout) @@ -1547,7 +1775,7 @@ def _wait_change_using_wait(self, change_id, timeout): raise TimeoutError('timed out waiting for change {} ({} seconds)'.format( change_id, timeout)) - def _wait_change(self, change_id: ChangeID, timeout: float = None) -> Change: + def _wait_change(self, change_id: ChangeID, timeout: Optional[float] = None) -> Change: """Call the wait-change API endpoint directly.""" query = {} if timeout is not None: @@ -1565,9 +1793,10 @@ def _wait_change(self, change_id: ChangeID, timeout: float = None) -> Change: return Change.from_dict(resp['result']) - def _wait_change_using_polling(self, change_id, timeout, delay): + def _wait_change_using_polling(self, change_id: ChangeID, timeout: Optional[float], + delay: float): """Wait for a change to be ready by polling the get-change API.""" - deadline = time.time() + timeout if timeout is not None else None + deadline = time.time() + timeout if timeout is not None else 0 while timeout is None or time.time() < deadline: change = self.get_change(change_id) @@ -1580,7 +1809,8 @@ def _wait_change_using_polling(self, change_id, timeout, delay): change_id, timeout)) def add_layer( - self, label: str, layer: typing.Union[str, dict, Layer], *, combine: bool = False): + self, label: str, layer: Union[str, '_LayerDict', Layer], *, + combine: bool = False): """Dynamically add a new layer onto the Pebble configuration layers. If combine is False (the default), append the new layer as the top @@ -1615,7 +1845,7 @@ def get_plan(self) -> Plan: resp = self._request('GET', '/v1/plan', {'format': 'yaml'}) return Plan(resp['result']) - def get_services(self, names: typing.List[str] = None) -> typing.List[ServiceInfo]: + def get_services(self, names: Optional[Iterable[str]] = None) -> List[ServiceInfo]: """Get the service status for the configured services. If names is specified, only fetch the service status for the services @@ -1630,8 +1860,7 @@ def get_services(self, names: typing.List[str] = None) -> typing.List[ServiceInf def pull(self, path: str, *, - encoding: typing.Optional[str] = 'utf-8') -> typing.Union[typing.BinaryIO, - typing.TextIO]: + encoding: Optional[str] = 'utf-8') -> Union[BinaryIO, TextIO]: """Read a file's content from the remote system. Args: @@ -1687,12 +1916,13 @@ def pull(self, # removing opened files, and so we use the tempfile lib's # helper class to auto-delete on close/gc for us. if os.name != 'posix' or sys.platform == 'cygwin': - return tempfile._TemporaryFileWrapper(f, f.name, delete=True) + return tempfile._TemporaryFileWrapper( # type: ignore + f, f.name, delete=True) # type: ignore parser.remove_files() return f @staticmethod - def _raise_on_path_error(resp, path): + def _raise_on_path_error(resp: '_FilesResponse', path: str): result = resp['result'] or [] # in case it's null instead of [] paths = {item['path']: item for item in result} if path not in paths: @@ -1702,9 +1932,13 @@ def _raise_on_path_error(resp, path): raise PathError(error['kind'], error['message']) def push( - self, path: str, source: typing.Union[bytes, str, typing.BinaryIO, typing.TextIO], *, - encoding: str = 'utf-8', make_dirs: bool = False, permissions: int = None, - user_id: int = None, user: str = None, group_id: int = None, group: str = None): + self, path: str, source: '_IOSource', *, + encoding: str = 'utf-8', make_dirs: bool = False, + permissions: Optional[int] = None, + user_id: Optional[int] = None, + user: Optional[str] = None, + group_id: Optional[int] = None, + group: Optional[str] = None): """Write content to a given file path on the remote system. Args: @@ -1742,11 +1976,16 @@ def push( response = self._request_raw('POST', '/v1/files', None, headers, data) self._ensure_content_type(response.headers, 'application/json') resp = _json_loads(response.read()) - self._raise_on_path_error(resp, path) + # we need to cast the Dict[Any, Any] to _FilesResponse + self._raise_on_path_error(typing.cast('_FilesResponse', resp), path) @staticmethod - def _make_auth_dict(permissions, user_id, user, group_id, group) -> typing.Dict: - d = {} + def _make_auth_dict(permissions: Optional[int], + user_id: Optional[int], + user: Optional[str], + group_id: Optional[int], + group: Optional[str]) -> '_AuthDict': + d = {} # type: _AuthDict if permissions is not None: d['permissions'] = format(permissions, '03o') if user_id is not None: @@ -1759,20 +1998,21 @@ def _make_auth_dict(permissions, user_id, user, group_id, group) -> typing.Dict: d['group'] = group return d - def _encode_multipart(self, metadata, path, source, encoding): + def _encode_multipart(self, metadata: Dict[str, Any], path: str, + source: '_IOSource', encoding: str): # Python's stdlib mime/multipart handling is screwy and doesn't handle # binary properly, so roll our own. - if isinstance(source, str): - source = io.StringIO(source) + source_io = io.StringIO(source) # type: _AnyStrFileLikeIO elif isinstance(source, bytes): - source = io.BytesIO(source) - + source_io = io.BytesIO(source) # type: _AnyStrFileLikeIO + else: + source_io = source # type: _AnyStrFileLikeIO boundary = binascii.hexlify(os.urandom(16)) path_escaped = path.replace('"', '\\"').encode('utf-8') # NOQA: test_quote_backslashes content_type = 'multipart/form-data; boundary="' + boundary.decode('utf-8') + '"' - def generator(): + def generator() -> Generator[bytes, None, None]: yield b''.join([ b'--', boundary, b'\r\n', b'Content-Type: application/json\r\n', @@ -1786,12 +2026,12 @@ def generator(): b'\r\n', ]) - content = source.read(self._chunk_size) + content = source_io.read(self._chunk_size) # type: '_StrOrBytes' while content: if isinstance(content, str): content = content.encode(encoding) yield content - content = source.read(self._chunk_size) + content = source_io.read(self._chunk_size) yield b''.join([ b'\r\n', @@ -1800,8 +2040,8 @@ def generator(): return generator(), content_type - def list_files(self, path: str, *, pattern: str = None, - itself: bool = False) -> typing.List[FileInfo]: + def list_files(self, path: str, *, pattern: Optional[str] = None, + itself: bool = False) -> List[FileInfo]: """Return list of directory entries from given path on remote system. Despite the name, this method returns a list of files *and* @@ -1828,8 +2068,12 @@ def list_files(self, path: str, *, pattern: str = None, return [FileInfo.from_dict(d) for d in result] def make_dir( - self, path: str, *, make_parents: bool = False, permissions: int = None, - user_id: int = None, user: str = None, group_id: int = None, group: str = None): + self, path: str, *, make_parents: bool = False, + permissions: Optional[int] = None, + user_id: Optional[int] = None, + user: Optional[str] = None, + group_id: Optional[int] = None, + group: Optional[str] = None): """Create a directory on the remote system with the given attributes. Args: @@ -1853,7 +2097,7 @@ def make_dir( 'dirs': [info], } resp = self._request('POST', '/v1/files', None, body) - self._raise_on_path_error(resp, path) + self._raise_on_path_error(typing.cast('_FilesResponse', resp), path) def remove_path(self, path: str, *, recursive: bool = False): """Remove a file or directory on the remote system. @@ -1866,7 +2110,7 @@ def remove_path(self, path: str, *, recursive: bool = False): to `rm -rf ` """ - info = {'path': path} + info = {'path': path} # type: Dict[str, Any] if recursive: info['recursive'] = True body = { @@ -1874,23 +2118,23 @@ def remove_path(self, path: str, *, recursive: bool = False): 'paths': [info], } resp = self._request('POST', '/v1/files', None, body) - self._raise_on_path_error(resp, path) + self._raise_on_path_error(typing.cast('_FilesResponse', resp), path) def exec( self, - command: typing.List[str], + command: List[str], *, - environment: typing.Dict[str, str] = None, - working_dir: str = None, - timeout: float = None, - user_id: int = None, - user: str = None, - group_id: int = None, - group: str = None, - stdin: typing.Union[str, bytes, typing.TextIO, typing.BinaryIO] = None, - stdout: typing.Union[typing.TextIO, typing.BinaryIO] = None, - stderr: typing.Union[typing.TextIO, typing.BinaryIO] = None, - encoding: str = 'utf-8', + environment: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + timeout: Optional[float] = None, + user_id: Optional[int] = None, + user: Optional[str] = None, + group_id: Optional[int] = None, + group: Optional[str] = None, + stdin: Optional['_IOSource'] = None, + stdout: Optional['_TextOrBinaryIO'] = None, + stderr: Optional['_TextOrBinaryIO'] = None, + encoding: Optional[str] = 'utf-8', combine_stderr: bool = False ) -> ExecProcess: r"""Execute the given command on the remote system. @@ -2038,7 +2282,7 @@ def exec( change_id = resp['change'] task_id = resp['result']['task-id'] - stderr_ws = None + stderr_ws = None # type: Optional[websocket.WebSocket] try: control_ws = self._connect_websocket(task_id, 'control') stdio_ws = self._connect_websocket(task_id, 'stdio') @@ -2052,9 +2296,9 @@ def exec( raise ChangeError(change.err, change) raise ConnectionError('unexpected error connecting to websockets: {}'.format(e)) - cancel_stdin = None - cancel_reader = None - threads = [] + cancel_stdin = None # type: Optional[Callable[[], None]] + cancel_reader = None # type: Optional[int] + threads = [] # type: List[threading.Thread] if stdin is not None: if _has_fileno(stdin): @@ -2066,9 +2310,10 @@ def exec( # to cancel_writer it'll trigger the select and end the thread. cancel_reader, cancel_writer = os.pipe() - def cancel_stdin(): + def _cancel_stdin(): os.write(cancel_writer, b'x') # doesn't matter what we write os.close(cancel_writer) + cancel_stdin = _cancel_stdin t = _start_thread(_reader_to_websocket, stdin, stdio_ws, encoding, cancel_reader) threads.append(t) @@ -2076,7 +2321,8 @@ def cancel_stdin(): else: process_stdin = _WebsocketWriter(stdio_ws) if encoding is not None: - process_stdin = io.TextIOWrapper(process_stdin, encoding=encoding, newline='') + process_stdin = io.TextIOWrapper( + process_stdin, encoding=encoding, newline='') # type: ignore if stdout is not None: t = _start_thread(_websocket_to_writer, stdio_ws, stdout, encoding) @@ -2085,7 +2331,8 @@ def cancel_stdin(): else: process_stdout = _WebsocketReader(stdio_ws) if encoding is not None: - process_stdout = io.TextIOWrapper(process_stdout, encoding=encoding, newline='') + process_stdout = io.TextIOWrapper( + process_stdout, encoding=encoding, newline='') # type: ignore process_stderr = None if not combine_stderr: @@ -2093,15 +2340,16 @@ def cancel_stdin(): t = _start_thread(_websocket_to_writer, stderr_ws, stderr, encoding) threads.append(t) else: - process_stderr = _WebsocketReader(stderr_ws) + ws = typing.cast(websocket.WebSocket, stderr_ws) + process_stderr = _WebsocketReader(ws) if encoding is not None: process_stderr = io.TextIOWrapper( - process_stderr, encoding=encoding, newline='') + process_stderr, encoding=encoding, newline='') # type: ignore process = ExecProcess( stdin=process_stdin, - stdout=process_stdout, - stderr=process_stderr, + stdout=process_stdout, # type: ignore # doesn't like _Writeable + stderr=process_stderr, # type: ignore # doesn't like _Writeable client=self, timeout=timeout, stdio_ws=stdio_ws, @@ -2121,7 +2369,7 @@ def _connect_websocket(self, task_id: str, websocket_id: str) -> websocket.WebSo sock.connect(self.socket_path) url = self._websocket_url(task_id, websocket_id) ws = websocket.WebSocket(skip_utf8_validation=True) - ws.connect(url, socket=sock) + ws.connect(url, socket=sock) # type: ignore return ws def _websocket_url(self, task_id: str, websocket_id: str) -> str: @@ -2129,7 +2377,7 @@ def _websocket_url(self, task_id: str, websocket_id: str) -> str: url = '{}/v1/tasks/{}/websocket/{}'.format(base_url, task_id, websocket_id) return url - def send_signal(self, sig: typing.Union[int, str], services: typing.List[str]): + def send_signal(self, sig: Union[int, str], services: Iterable[str]): """Send the given signal to the list of services named. Args: @@ -2141,11 +2389,11 @@ def send_signal(self, sig: typing.Union[int, str], services: typing.List[str]): APIError: If any of the services are not in the plan or are not currently running. """ - if not isinstance(services, (list, tuple)): - raise TypeError('services must be a list of str, not {}'.format( - type(services).__name__)) + if isinstance(services, (str, bytes)) or not hasattr(services, '__iter__'): + raise TypeError('services must be of type Iterable[str], ' + 'not {}'.format(type(services).__name__)) for s in services: - if not isinstance(s, str): + if not isinstance(s, str): # pyright: reportUnnecessaryIsInstance=false raise TypeError('service names must be str, not {}'.format(type(s).__name__)) if isinstance(sig, int): @@ -2158,9 +2406,9 @@ def send_signal(self, sig: typing.Union[int, str], services: typing.List[str]): def get_checks( self, - level: CheckLevel = None, - names: typing.List[str] = None - ) -> typing.List[CheckInfo]: + level: Optional[CheckLevel] = None, + names: Optional[Iterable[str]] = None + ) -> List[CheckInfo]: """Get the check status for the configured checks. Args: @@ -2176,7 +2424,7 @@ def get_checks( if level is not None: query['level'] = level.value if names: - query['names'] = names + query['names'] = list(names) resp = self._request('GET', '/v1/checks', query) return [CheckInfo.from_dict(info) for info in resp['result']] @@ -2184,17 +2432,17 @@ def get_checks( class _FilesParser: """A limited purpose multi-part parser backed by files for memory efficiency.""" - def __init__(self, boundary: typing.Union[bytes, str]): - self._response = None - self._files = {} + def __init__(self, boundary: Union[bytes, str]): + self._response = None # type: Optional[_FilesResponse] # externally managed + self._part_type = None # type: Optional[Literal["response", "files"]] # externally managed + self._headers = None # type: Optional['Message'] # externally managed + self._files = {} # type: Dict[str, _Tempfile] # Prepare the MIME multipart boundary line patterns. if isinstance(boundary, str): boundary = boundary.encode() # State vars, as we may enter the feed() function multiple times. - self._headers = None - self._part_type = None self._response_data = bytearray() self._max_lookahead = 8 * 1024 * 1024 @@ -2230,15 +2478,16 @@ def _process_header(self, data: bytes): raise ProtocolError( 'unexpected name in content-disposition header: {!r}'.format(name)) - self._part_type = name + self._part_type = typing.cast('Literal["response", "files"]', name) - def _process_body(self, data: bytes, done=False): + def _process_body(self, data: bytes, done: bool = False): if self._part_type == 'response': self._response_data.extend(data) if done: if len(self._response_data) > self._max_lookahead: raise ProtocolError('response end marker not found') - self._response = json.loads(self._response_data.decode()) + resp = json.loads(self._response_data.decode()) + self._response = typing.cast('_FilesResponse', resp) self._response_data = bytearray() elif self._part_type == 'files': if done: @@ -2262,15 +2511,15 @@ def feed(self, data: bytes): """Provide more data to the running parser.""" self._parser.feed(data) - def _prepare_tempfile(self, filename): + def _prepare_tempfile(self, filename: str): tf = tempfile.NamedTemporaryFile(delete=False) - self._files[filename] = tf + self._files[filename] = tf # type: ignore # we have a custom protocol for it self.current_filename = filename def _get_open_tempfile(self): return self._files[self.current_filename] - def get_response(self): + def get_response(self) -> Optional['_FilesResponse']: """Return the deserialized JSON object from the multipart "response" field.""" return self._response @@ -2278,21 +2527,24 @@ def filenames(self): """Return a list of filenames from the "files" parts of the response.""" return list(self._files.keys()) - def get_file(self, path, encoding): + def get_file(self, path: str, encoding: Optional[str]) -> '_TextOrBinaryIO': """Return an open file object containing the data.""" mode = 'r' if encoding else 'rb' # We're using text-based file I/O purely for file encoding purposes, not for # newline normalization. newline='' serves the line endings as-is. newline = '' if encoding else None - return open(self._files[path].name, mode, encoding=encoding, newline=newline) + file_io = open(self._files[path].name, mode, + encoding=encoding, newline=newline) + # open() returns IO[Any] + return typing.cast('_TextOrBinaryIO', file_io) class _MultipartParser: def __init__( self, marker: bytes, - handle_header, - handle_body, + handle_header: '_HeaderHandler', + handle_body: '_BodyHandler', max_lookahead: int = 0, max_boundary_length: int = 0): r"""Configures a parser for mime multipart messages. @@ -2363,7 +2615,7 @@ def feed(self, data: bytes): self._pos = end else: # parse the part body - ii, nn, self._done = _next_part_boundary(self._buf, self._marker, start=self._pos) + ii, _, self._done = _next_part_boundary(self._buf, self._marker, start=self._pos) safe_bound = max(0, len(self._buf) - self._max_boundary_length) if ii != -1: # part body is finished @@ -2382,7 +2634,8 @@ def feed(self, data: bytes): return # waiting for more data -def _next_part_boundary(buf, marker, start=0): +def _next_part_boundary(buf: bytes, marker: bytes, start: int = 0 + ) -> Tuple[int, int, bool]: """Returns the index of the next boundary marker in buf beginning at start. Returns: diff --git a/pyproject.toml b/pyproject.toml index 169cf1c5a..f0ef98417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ docstring-convention = "google" [tool.pyright] include = ["ops/jujuversion.py", "ops/log.py", "ops/model.py", "ops/version.py", - "ops/__init__.py", "ops/framework.py"] + "ops/__init__.py", "ops/framework.py", "opps/pebble.py"] pythonVersion = "3.5" # check no python > 3.5 features are used pythonPlatform = "All" typeCheckingMode = "strict" diff --git a/test/test_pebble.py b/test/test_pebble.py index 8c57a00f3..924d11014 100644 --- a/test/test_pebble.py +++ b/test/test_pebble.py @@ -863,6 +863,18 @@ def test_dict(self): check.exec['command'] = 'foo' self.assertEqual(d['exec'], {'command': 'echo foo'}) + def test_level_raw(self): + d = { + 'override': 'replace', + 'level': 'foobar!', + 'period': '10s', + 'timeout': '3s', + 'threshold': 5, + 'http': {'url': 'https://example.com/'}, + } + check = pebble.Check('chk-http', d) + self.assertEqual(check.level, 'foobar!') # remains a string + def test_equality(self): d = { 'override': 'replace', @@ -1259,7 +1271,7 @@ def setUp(self): def test_client_init(self): pebble.Client(socket_path='foo') # test that constructor runs - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): pebble.Client() # socket_path arg required def test_get_system_info(self): @@ -2507,6 +2519,32 @@ def test_get_checks_filters(self): ('GET', '/v1/checks', {'level': 'ready', 'names': ['chk2']}, None), ]) + def test_checklevel_conversion(self): + self.client.responses.append({ + "result": [ + { + "name": "chk2", + "level": "foobar!", + "status": "up", + "threshold": 3, + }, + ], + "status": "OK", + "status-code": 200, + "type": "sync" + }) + checks = self.client.get_checks(level=pebble.CheckLevel.READY, names=['chk2']) + self.assertEqual(len(checks), 1) + self.assertEqual(checks[0].name, 'chk2') + self.assertEqual(checks[0].level, 'foobar!') # stays a raw string + self.assertEqual(checks[0].status, pebble.CheckStatus.UP) + self.assertEqual(checks[0].failures, 0) + self.assertEqual(checks[0].threshold, 3) + + self.assertEqual(self.client.requests, [ + ('GET', '/v1/checks', {'level': 'ready', 'names': ['chk2']}, None), + ]) + @unittest.skipIf(sys.platform == 'win32', "Unix sockets don't work on Windows") class TestSocketClient(unittest.TestCase):