diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d74544a..482936a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,11 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. rev: v0.1.9 hooks: - # Run the linter. - id: ruff args: ["--fix"] - # Run the formatter. - id: ruff-format + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.350 + hooks: + - id: pyright diff --git a/findmy/accessory.py b/findmy/accessory.py index 2864243..b62f3ad 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -7,7 +7,9 @@ import logging from datetime import datetime, timedelta -from typing import Generator +from typing import Generator, overload + +from typing_extensions import override from .keys import KeyGenerator, KeyPair, KeyType from .util import crypto @@ -141,15 +143,26 @@ def _generate_keys(self, start: int, stop: int | None) -> Generator[KeyPair, Non ind += 1 + @override def __iter__(self) -> KeyGenerator: self._iter_ind = -1 return self + @override def __next__(self) -> KeyPair: self._iter_ind += 1 return self._get_keypair(self._iter_ind) + @overload + def __getitem__(self, val: int) -> KeyPair: + ... + + @overload + def __getitem__(self, val: slice) -> Generator[KeyPair, None, None]: + ... + + @override def __getitem__(self, val: int | slice) -> KeyPair | Generator[KeyPair, None, None]: if isinstance(val, int): if val < 0: diff --git a/findmy/keys.py b/findmy/keys.py index 42d5ce5..f51c5bf 100644 --- a/findmy/keys.py +++ b/findmy/keys.py @@ -9,6 +9,7 @@ from typing import Generator, Generic, TypeVar, overload from cryptography.hazmat.primitives.asymmetric import ec +from typing_extensions import override from .util import crypto @@ -49,10 +50,12 @@ def hashed_adv_key_b64(self) -> str: """Return the hashed advertised (public) key as a base64-encoded string.""" return base64.b64encode(self.hashed_adv_key_bytes).decode("ascii") + @override def __hash__(self) -> int: return crypto.bytes_to_int(self.adv_key_bytes) - def __eq__(self, other: HasPublicKey) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, HasPublicKey): return NotImplemented @@ -107,6 +110,7 @@ def private_key_b64(self) -> str: return base64.b64encode(self.private_key_bytes).decode("ascii") @property + @override def adv_key_bytes(self) -> bytes: """Return the advertised (public) key as bytes.""" key_bytes = self._priv_key.public_key().public_numbers().x @@ -116,6 +120,7 @@ def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes: """Do a Diffie-Hellman key exchange using another EC public key.""" return self._priv_key.exchange(ec.ECDH(), other_pub_key) + @override def __repr__(self) -> str: return f'KeyPair(public_key="{self.adv_key_b64}", type={self.key_type})' @@ -135,11 +140,13 @@ def __next__(self) -> K: return NotImplemented @overload + @abstractmethod def __getitem__(self, val: int) -> K: ... @overload - def __getitem__(self, slc: slice) -> Generator[K, None, None]: + @abstractmethod + def __getitem__(self, val: slice) -> Generator[K, None, None]: ... @abstractmethod diff --git a/findmy/reports/account.py b/findmy/reports/account.py index dd88fec..5413162 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -11,18 +11,14 @@ import uuid from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone -from typing import ( - TYPE_CHECKING, - Any, - Sequence, - TypedDict, -) +from typing import TYPE_CHECKING, Any, Sequence, TypedDict import bs4 import srp._pysrp as srp from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from typing_extensions import override from findmy.util.closable import Closable from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError @@ -40,6 +36,7 @@ if TYPE_CHECKING: from findmy.keys import KeyPair + from findmy.util.types import MaybeCoro from .anisette import BaseAnisetteProvider @@ -83,7 +80,7 @@ def _decrypt_cbc(session_key: bytes, data: bytes) -> bytes: def _extract_phone_numbers(html: str) -> list[dict]: soup = bs4.BeautifulSoup(html, features="html.parser") - data_elem = soup.find("script", **{"class": "boot_args"}) + data_elem = soup.find("script", {"class": "boot_args"}) if not data_elem: msg = "Could not find HTML element containing phone numbers" raise RuntimeError(msg) @@ -103,7 +100,7 @@ def login_state(self) -> LoginState: @property @abstractmethod - def account_name(self) -> str: + def account_name(self) -> str | None: """ The name of the account as reported by Apple. @@ -155,12 +152,12 @@ def restore(self, data: dict) -> None: raise NotImplementedError @abstractmethod - def login(self, username: str, password: str) -> LoginState: + def login(self, username: str, password: str) -> MaybeCoro[LoginState]: """Log in to an Apple account using a username and password.""" raise NotImplementedError @abstractmethod - def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + def get_2fa_methods(self) -> MaybeCoro[Sequence[BaseSecondFactorMethod]]: """ Get a list of 2FA methods that can be used as a secondary challenge. @@ -169,7 +166,7 @@ def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: raise NotImplementedError @abstractmethod - def sms_2fa_request(self, phone_number_id: int) -> None: + def sms_2fa_request(self, phone_number_id: int) -> MaybeCoro[None]: """ Request a 2FA code to be sent to a specific phone number ID. @@ -178,7 +175,7 @@ def sms_2fa_request(self, phone_number_id: int) -> None: raise NotImplementedError @abstractmethod - def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: + def sms_2fa_submit(self, phone_number_id: int, code: str) -> MaybeCoro[LoginState]: """ Submit a 2FA code that was sent to a specific phone number ID. @@ -192,7 +189,7 @@ def fetch_reports( keys: Sequence[KeyPair], date_from: datetime, date_to: datetime | None, - ) -> dict[KeyPair, list[LocationReport]]: + ) -> MaybeCoro[dict[KeyPair, list[LocationReport]]]: """ Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`. @@ -205,7 +202,7 @@ def fetch_last_reports( self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ) -> dict[KeyPair, list[LocationReport]]: + ) -> MaybeCoro[dict[KeyPair, list[LocationReport]]]: """ Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours. @@ -214,7 +211,7 @@ def fetch_last_reports( raise NotImplementedError @abstractmethod - def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: + def get_anisette_headers(self, serial: str = "0") -> MaybeCoro[dict[str, str]]: """ Retrieve a complete dictionary of Anisette headers. @@ -273,6 +270,7 @@ def _set_login_state( return state @property + @override def login_state(self) -> LoginState: """See `BaseAppleAccount.login_state`.""" return self._login_state @@ -283,6 +281,7 @@ def login_state(self) -> LoginState: LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA, ) + @override def account_name(self) -> str | None: """See `BaseAppleAccount.account_name`.""" return self._account_info["account_name"] if self._account_info else None @@ -293,6 +292,7 @@ def account_name(self) -> str | None: LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA, ) + @override def first_name(self) -> str | None: """See `BaseAppleAccount.first_name`.""" return self._account_info["first_name"] if self._account_info else None @@ -303,10 +303,12 @@ def first_name(self) -> str | None: LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA, ) + @override def last_name(self) -> str | None: """See `BaseAppleAccount.last_name`.""" return self._account_info["last_name"] if self._account_info else None + @override def export(self) -> dict: """See `BaseAppleAccount.export`.""" return { @@ -322,6 +324,7 @@ def export(self) -> dict: }, } + @override def restore(self, data: dict) -> None: """See `BaseAppleAccount.restore`.""" try: @@ -348,6 +351,7 @@ async def close(self) -> None: await self._http.close() @require_login_state(LoginState.LOGGED_OUT) + @override async def login(self, username: str, password: str) -> LoginState: """See `BaseAppleAccount.login`.""" # LOGGED_OUT -> (REQUIRE_2FA or AUTHENTICATED) @@ -359,7 +363,8 @@ async def login(self, username: str, password: str) -> LoginState: return await self._login_mobileme() @require_login_state(LoginState.REQUIRE_2FA) - async def get_2fa_methods(self) -> list[AsyncSecondFactorMethod]: + @override + async def get_2fa_methods(self) -> Sequence[AsyncSecondFactorMethod]: """See `BaseAppleAccount.get_2fa_methods`.""" methods: list[AsyncSecondFactorMethod] = [] @@ -370,8 +375,8 @@ async def get_2fa_methods(self) -> list[AsyncSecondFactorMethod]: methods.extend( AsyncSmsSecondFactor( self, - number.get("id"), - number.get("numberWithDialCode"), + number.get("id") or -1, + number.get("numberWithDialCode") or "-", ) for number in phone_numbers ) @@ -381,6 +386,7 @@ async def get_2fa_methods(self) -> list[AsyncSecondFactorMethod]: return methods @require_login_state(LoginState.REQUIRE_2FA) + @override async def sms_2fa_request(self, phone_number_id: int) -> None: """See `BaseAppleAccount.sms_2fa_request`.""" data = {"phoneNumber": {"id": phone_number_id}, "mode": "sms"} @@ -392,6 +398,7 @@ async def sms_2fa_request(self, phone_number_id: int) -> None: ) @require_login_state(LoginState.REQUIRE_2FA) + @override async def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: """See `BaseAppleAccount.sms_2fa_submit`.""" data = { @@ -435,8 +442,9 @@ async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[ raise UnhandledProtocolError(msg) return resp - + @require_login_state(LoginState.LOGGED_IN) + @override async def fetch_reports( self, keys: Sequence[KeyPair], @@ -453,6 +461,7 @@ async def fetch_reports( ) @require_login_state(LoginState.LOGGED_IN) + @override async def fetch_last_reports( self, keys: Sequence[KeyPair], @@ -521,11 +530,11 @@ async def _gsa_authenticate( logging.debug("Decrypting SPD data in response") - spd = _decrypt_cbc(usr.get_session_key(), r["spd"]) + spd = _decrypt_cbc(usr.get_session_key() or b"", r["spd"]) spd = decode_plist(spd) logging.debug("Received account information") - self._account_info: _AccountInfo = { + self._account_info = { "account_name": spd.get("acname"), "first_name": spd.get("fn"), "last_name": spd.get("ln"), @@ -575,7 +584,7 @@ async def _login_mobileme(self) -> LoginState: resp = await self._http.post( "https://setup.icloud.com/setup/iosbuddy/loginDelegates", - auth=(self._username, self._login_state_data["idms_pet"]), + auth=(self._username or "", self._login_state_data["idms_pet"]), data=data, headers=headers, ) @@ -597,7 +606,7 @@ async def _sms_2fa_request( self, method: str, url: str, - data: dict | None = None, + data: dict[str, Any] | None = None, ) -> str: adsid = self._login_state_data["adsid"] idms_token = self._login_state_data["idms_token"] @@ -617,7 +626,7 @@ async def _sms_2fa_request( r = await self._http.request( method, url, - json=data, + json=data or {}, headers=headers, ) if not r.ok: @@ -662,6 +671,7 @@ async def _gsa_request(self, params: dict[str, Any]) -> dict[Any, Any]: raise UnhandledProtocolError(msg) return resp.plist()["Response"] + @override async def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: """See `BaseAppleAccount.get_anisette_headers`.""" return await self._anisette.get_headers(self._uid, self._devid, serial) @@ -696,39 +706,47 @@ async def close(self) -> None: await self._asyncacc.close() @property + @override def login_state(self) -> LoginState: """See `AsyncAppleAccount.login_state`.""" return self._asyncacc.login_state @property - def account_name(self) -> str: + @override + def account_name(self) -> str | None: """See `AsyncAppleAccount.login_state`.""" return self._asyncacc.account_name @property + @override def first_name(self) -> str | None: """See `AsyncAppleAccount.first_name`.""" return self._asyncacc.first_name @property + @override def last_name(self) -> str | None: """See `AsyncAppleAccount.last_name`.""" return self._asyncacc.last_name + @override def export(self) -> dict: """See `AsyncAppleAccount.export`.""" return self._asyncacc.export() + @override def restore(self, data: dict) -> None: """See `AsyncAppleAccount.restore`.""" return self._asyncacc.restore(data) + @override def login(self, username: str, password: str) -> LoginState: """See `AsyncAppleAccount.login`.""" coro = self._asyncacc.login(username, password) return self._loop.run_until_complete(coro) - def get_2fa_methods(self) -> list[SyncSecondFactorMethod]: + @override + def get_2fa_methods(self) -> Sequence[SyncSecondFactorMethod]: """See `AsyncAppleAccount.get_2fa_methods`.""" coro = self._asyncacc.get_2fa_methods() methods = self._loop.run_until_complete(coro) @@ -746,16 +764,19 @@ def get_2fa_methods(self) -> list[SyncSecondFactorMethod]: return res + @override def sms_2fa_request(self, phone_number_id: int) -> None: """See `AsyncAppleAccount.sms_2fa_request`.""" coro = self._asyncacc.sms_2fa_request(phone_number_id) return self._loop.run_until_complete(coro) + @override def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: """See `AsyncAppleAccount.sms_2fa_submit`.""" coro = self._asyncacc.sms_2fa_submit(phone_number_id, code) return self._loop.run_until_complete(coro) + @override def fetch_reports( self, keys: Sequence[KeyPair], @@ -766,6 +787,7 @@ def fetch_reports( coro = self._asyncacc.fetch_reports(keys, date_from, date_to) return self._loop.run_until_complete(coro) + @override def fetch_last_reports( self, keys: Sequence[KeyPair], @@ -775,6 +797,7 @@ def fetch_last_reports( coro = self._asyncacc.fetch_last_reports(keys, hours) return self._loop.run_until_complete(coro) + @override def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: """See `AsyncAppleAccount.get_anisette_headers`.""" coro = self._asyncacc.get_anisette_headers(serial) diff --git a/findmy/reports/anisette.py b/findmy/reports/anisette.py index 8ce9d37..bdf4b2e 100644 --- a/findmy/reports/anisette.py +++ b/findmy/reports/anisette.py @@ -70,6 +70,7 @@ def __init__(self, server_url: str) -> None: logging.info("Using remote anisette server: %s", self._server_url) + @override async def _get_base_headers(self) -> dict[str, str]: r = await self._http.get(self._server_url) headers = r.json() @@ -94,6 +95,7 @@ def __init__(self) -> None: """Initialize the provider.""" super().__init__() + @override async def _get_base_headers(self) -> dict[str, str]: return NotImplemented diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index e224bd7..cf818d6 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -10,7 +10,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from typing_extensions import Unpack +from typing_extensions import override, Unpack from findmy.keys import KeyPair from findmy.util.http import HttpSession @@ -161,11 +161,12 @@ def __lt__(self, other: LocationReport) -> bool: return self.timestamp < other.timestamp return NotImplemented + @override def __repr__(self) -> str: """Human-readable string representation of the location report.""" return ( - f"" + f"KeyReport(key={self._key.hashed_adv_key_b64}, timestamp={self._timestamp}," + f" lat={self._lat}, lng={self._lng})" ) diff --git a/findmy/reports/state.py b/findmy/reports/state.py index 5cf751e..3d832ff 100644 --- a/findmy/reports/state.py +++ b/findmy/reports/state.py @@ -1,18 +1,13 @@ """Code related to internal account state handling.""" from enum import Enum from functools import wraps -from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar +from typing import Callable, Concatenate, ParamSpec, TypeVar -from findmy.util.errors import InvalidStateError +from typing_extensions import override -if TYPE_CHECKING: - # noinspection PyUnresolvedReferences - from .account import BaseAppleAccount +from findmy.util.errors import InvalidStateError -P = ParamSpec("P") -R = TypeVar("R") -A = TypeVar("A", bound="BaseAppleAccount") -F = Callable[Concatenate[A, P], R] +from .account import BaseAppleAccount class LoginState(Enum): @@ -35,17 +30,28 @@ def __lt__(self, other: "LoginState") -> bool: return NotImplemented + @override def __repr__(self) -> str: """Human-readable string representation of the state.""" return self.__str__() -def require_login_state(*states: LoginState) -> Callable[[F], F]: +_P = ParamSpec("_P") +_R = TypeVar("_R") +_A = TypeVar("_A", bound="BaseAppleAccount") +_F = Callable[Concatenate[_A, _P], _R] + + +def require_login_state(*states: LoginState) -> Callable[[_F], _F]: """Enforce a login state as precondition for a method.""" - def decorator(func: F) -> F: + def decorator(func: _F) -> _F: @wraps(func) - def wrapper(acc: A, *args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(acc: _A, *args: _P.args, **kwargs: _P.kwargs) -> _R: # pyright: ignore [reportInvalidTypeVarUse] + if not isinstance(args[0], BaseAppleAccount): + msg = "This decorator can only be used on instances of BaseAppleAccount." + raise TypeError(msg) + if acc.login_state not in states: msg = ( f"Invalid login state! Currently: {acc.login_state}" diff --git a/findmy/reports/twofactor.py b/findmy/reports/twofactor.py index c2ff81e..005beef 100644 --- a/findmy/reports/twofactor.py +++ b/findmy/reports/twofactor.py @@ -1,6 +1,10 @@ """Public classes related to handling two-factor authentication.""" -from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, TypeVar +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +from typing_extensions import override + +from findmy.util.types import MaybeCoro from .state import LoginState @@ -8,23 +12,23 @@ # noinspection PyUnresolvedReferences from .account import AppleAccount, AsyncAppleAccount, BaseAppleAccount -T = TypeVar("T", bound="BaseAppleAccount") +_AccType = TypeVar("_AccType", bound="BaseAppleAccount") -class BaseSecondFactorMethod(metaclass=ABCMeta): +class BaseSecondFactorMethod(ABC, Generic[_AccType]): """Base class for a second-factor authentication method for an Apple account.""" - def __init__(self, account: T) -> None: + def __init__(self, account: _AccType) -> None: """Initialize the second-factor method.""" - self._account: T = account + self._account: _AccType = account @property - def account(self) -> T: + def account(self) -> _AccType: """The account associated with the second-factor method.""" return self._account @abstractmethod - def request(self) -> None: + def request(self) -> MaybeCoro[None]: """ Put in a request for the second-factor challenge. @@ -33,12 +37,12 @@ def request(self) -> None: raise NotImplementedError @abstractmethod - def submit(self, code: str) -> LoginState: + def submit(self, code: str) -> MaybeCoro[LoginState]: """Submit a code to complete the second-factor challenge.""" raise NotImplementedError -class AsyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): +class AsyncSecondFactorMethod(BaseSecondFactorMethod, ABC): """ An asynchronous implementation of a second-factor authentication method. @@ -50,12 +54,25 @@ def __init__(self, account: "AsyncAppleAccount") -> None: super().__init__(account) @property + @override def account(self) -> "AsyncAppleAccount": """The account associated with the second-factor method.""" return self._account + @override + @abstractmethod + async def request(self) -> None: + """See `BaseSecondFactorMethod.request`.""" + raise NotImplementedError + + @override + @abstractmethod + async def submit(self, code: str) -> LoginState: + """See `BaseSecondFactorMethod.submit`.""" + raise NotImplementedError + -class SyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): +class SyncSecondFactorMethod(BaseSecondFactorMethod, ABC): """ A synchronous implementation of a second-factor authentication method. @@ -67,12 +84,25 @@ def __init__(self, account: "AppleAccount") -> None: super().__init__(account) @property + @override def account(self) -> "AppleAccount": """The account associated with the second-factor method.""" return self._account + @override + @abstractmethod + def request(self) -> None: + """See `BaseSecondFactorMethod.request`.""" + raise NotImplementedError + + @override + @abstractmethod + def submit(self, code: str) -> LoginState: + """See `BaseSecondFactorMethod.submit`.""" + raise NotImplementedError + -class SmsSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): +class SmsSecondFactorMethod(BaseSecondFactorMethod, ABC): """Base class for SMS-based two-factor authentication.""" @property @@ -112,11 +142,13 @@ def __init__( self._phone_number: str = phone_number @property + @override def phone_number_id(self) -> int: """The phone number's ID. You most likely don't need this.""" return self._phone_number_id @property + @override def phone_number(self) -> str: """ The 2FA method's phone number. @@ -125,10 +157,12 @@ def phone_number(self) -> str: """ return self._phone_number + @override async def request(self) -> None: """Request an SMS to the corresponding phone number containing a 2FA code.""" return await self.account.sms_2fa_request(self._phone_number_id) + @override async def submit(self, code: str) -> LoginState: """See `BaseSecondFactorMethod.submit`.""" return await self.account.sms_2fa_submit(self._phone_number_id, code) @@ -154,19 +188,23 @@ def __init__( self._phone_number: str = phone_number @property + @override def phone_number_id(self) -> int: """See `AsyncSmsSecondFactor.phone_number_id`.""" return self._phone_number_id @property + @override def phone_number(self) -> str: """See `AsyncSmsSecondFactor.phone_number`.""" return self._phone_number + @override def request(self) -> None: """See `AsyncSmsSecondFactor.request`.""" return self.account.sms_2fa_request(self._phone_number_id) + @override def submit(self, code: str) -> LoginState: """See `AsyncSmsSecondFactor.submit`.""" return self.account.sms_2fa_submit(self._phone_number_id, code) diff --git a/findmy/scanner/scanner.py b/findmy/scanner/scanner.py index 044a01e..0371644 100644 --- a/findmy/scanner/scanner.py +++ b/findmy/scanner/scanner.py @@ -4,12 +4,17 @@ import asyncio import logging import time -from typing import Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator -import bleak +from bleak import BleakScanner +from typing_extensions import override from findmy.keys import HasPublicKey +if TYPE_CHECKING: + from bleak.backends.device import BLEDevice + from bleak.backends.scanner import AdvertisementData + logging.getLogger(__name__) @@ -58,6 +63,7 @@ def additional_data(self) -> dict[Any, Any]: return self._additional_data @property + @override def adv_key_bytes(self) -> bytes: """See `HasPublicKey.adv_key_bytes`.""" return self._public_key @@ -102,6 +108,7 @@ def from_payload( return OfflineFindingDevice(mac_bytes, status, pubkey, hint, additional_data) + @override def __repr__(self) -> str: """Human-readable string representation of an OfflineFindingDevice.""" return ( @@ -109,16 +116,6 @@ def __repr__(self) -> str: f" status={self.status}, hint={self.hint})" ) - def __eq__(self, other: OfflineFindingDevice) -> bool: - """Check if two OfflineFindingDevices are equal by comparing their MAC addresses.""" - if not isinstance(other, OfflineFindingDevice): - return False - return other.mac_address == self.mac_address - - def __hash__(self) -> int: - """Hash an OfflineFindingDevice. This is simply the MAC address as an integer.""" - return int.from_bytes(self._mac_bytes, "big") - class OfflineFindingScanner: """BLE scanner that searches for `OfflineFindingDevice`s.""" @@ -134,12 +131,10 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: You most likely do not want to use this yourself; check out `OfflineFindingScanner.create` instead. """ - self._scanner: bleak.BleakScanner = bleak.BleakScanner(self._scan_callback) + self._scanner: BleakScanner = BleakScanner(self._scan_callback) self._loop = loop - self._device_fut: asyncio.Future[ - (bleak.BLEDevice, bleak.AdvertisementData) - ] = loop.create_future() + self._device_fut: asyncio.Future[tuple[BLEDevice, AdvertisementData]] = loop.create_future() self._scanner_count: int = 0 @@ -165,8 +160,8 @@ async def _stop_scan(self) -> None: async def _scan_callback( self, - device: bleak.BLEDevice, - data: bleak.AdvertisementData, + device: BLEDevice, + data: AdvertisementData, ) -> None: self._device_fut.set_result((device, data)) self._device_fut = self._loop.create_future() @@ -186,7 +181,7 @@ async def scan_for( timeout: float = 10, *, extend_timeout: bool = False, - ) -> AsyncGenerator[OfflineFindingDevice]: + ) -> AsyncGenerator[OfflineFindingDevice, None]: """ Scan for `OfflineFindingDevice`s for up to `timeout` seconds. diff --git a/findmy/util/http.py b/findmy/util/http.py index 6689dbf..5512f5d 100644 --- a/findmy/util/http.py +++ b/findmy/util/http.py @@ -3,9 +3,10 @@ import json import logging -from typing import Any, ParamSpec +from typing import Any, TypedDict from aiohttp import BasicAuth, ClientSession, ClientTimeout +from typing_extensions import Unpack from .closable import Closable from .parsers import decode_plist @@ -13,6 +14,13 @@ logging.getLogger(__name__) +class _HttpRequestOptions(TypedDict, total=False): + json: dict[str, Any] + headers: dict[str, str] + auth: tuple[str, str] | BasicAuth + data: bytes + + class HttpResponse: """Response of a request made by `HttpSession`.""" @@ -49,9 +57,6 @@ def plist(self) -> dict[Any, Any]: return data -P = ParamSpec("P") - - class HttpSession(Closable): """Asynchronous HTTP session manager. For internal use only.""" @@ -60,10 +65,13 @@ def __init__(self) -> None: # noqa: D107 self._session: ClientSession | None = None - async def _ensure_session(self) -> None: - if self._session is None: - logging.debug("Creating aiohttp session") - self._session = ClientSession(timeout=ClientTimeout(total=5)) + async def _get_session(self) -> ClientSession: + if self._session is not None: + return self._session + + logging.debug("Creating aiohttp session") + self._session = ClientSession(timeout=ClientTimeout(total=5)) + return self._session async def close(self) -> None: """Close the underlying session. Should be called when session will no longer be used.""" @@ -76,33 +84,31 @@ async def request( self, method: str, url: str, - auth: tuple[str] | None = None, - **kwargs: P.kwargs, + **kwargs: Unpack[_HttpRequestOptions], ) -> HttpResponse: """ Make an HTTP request. Keyword arguments will directly be passed to `aiohttp.ClientSession.request`. """ - await self._ensure_session() + session = await self._get_session() - basic_auth = None - if auth is not None: - basic_auth = BasicAuth(auth[0], auth[1]) + auth = kwargs.get("auth") + if isinstance(auth, tuple): + kwargs["auth"] = BasicAuth(auth[0], auth[1]) - async with await self._session.request( + async with await session.request( method, url, - auth=basic_auth, ssl=False, **kwargs, ) as r: return HttpResponse(r.status, await r.content.read()) - async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse: + async def get(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse: """Alias for `HttpSession.request("GET", ...)`.""" return await self.request("GET", url, **kwargs) - async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse: + async def post(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse: """Alias for `HttpSession.request("POST", ...)`.""" return await self.request("POST", url, **kwargs) diff --git a/findmy/util/types.py b/findmy/util/types.py new file mode 100644 index 0000000..8c754aa --- /dev/null +++ b/findmy/util/types.py @@ -0,0 +1,7 @@ +"""Utility types.""" + +from typing import Coroutine, TypeVar + +T = TypeVar("T") + +MaybeCoro = T | Coroutine[None, None, T] diff --git a/poetry.lock b/poetry.lock index 4af3779..c78d880 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1073,6 +1073,24 @@ files = [ [package.dependencies] pyobjc-core = ">=9.2" +[[package]] +name = "pyright" +version = "1.1.350" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.350-py3-none-any.whl", hash = "sha256:f1dde6bcefd3c90aedbe9dd1c573e4c1ddbca8c74bf4fa664dd3b1a599ac9a66"}, + {file = "pyright-1.1.350.tar.gz", hash = "sha256:a8ba676de3a3737ea4d8590604da548d4498cc5ee9ee00b1a403c6db987916c6"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" + +[package.extras] +all = ["twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1760,4 +1778,4 @@ scan = ["bleak"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "f0a6463477183b86d152b71aa14404fea80a3b28df20c43782eb56de00db8d91" +content-hash = "696a56ccbba231e3ec702aaee911977819b996d21074b37807c42a45d107c7ab" diff --git a/pyproject.toml b/pyproject.toml index c9e277a..02c3e8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,15 @@ scan = ["bleak"] pre-commit = "^3.6.0" sphinx = "^7.2.6" sphinx-autoapi = "^3.0.0" +pyright = "^1.1.350" + +[tool.pyright] +venvPath = "." +venv = ".venv" + +# rule overrides +typeCheckingMode = "standard" +reportImplicitOverride = true [tool.ruff] exclude = [