Skip to content

Commit

Permalink
Merge pull request #6 from malmeloo/fix/stateful-report-fetch
Browse files Browse the repository at this point in the history
  • Loading branch information
malmeloo authored Feb 10, 2024
2 parents 55d6c9f + 223d421 commit 7e3a1ed
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 97 deletions.
65 changes: 44 additions & 21 deletions findmy/reports/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from findmy.util import HttpSession, decode_plist
from findmy.util.closable import Closable
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
from findmy.util.http import HttpSession, decode_plist

from .reports import KeyReport, fetch_reports
from .reports import LocationReport, LocationReportsFetcher
from .state import LoginState, require_login_state
from .twofactor import (
AsyncSecondFactorMethod,
Expand Down Expand Up @@ -91,7 +92,7 @@ def _extract_phone_numbers(html: str) -> list[dict]:
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])


class BaseAppleAccount(ABC):
class BaseAppleAccount(Closable, ABC):
"""Base class for an Apple account."""

@property
Expand Down Expand Up @@ -190,8 +191,8 @@ def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
Expand All @@ -204,7 +205,7 @@ def fetch_last_reports(
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
Expand Down Expand Up @@ -238,6 +239,8 @@ def __init__(
:param user_id: An optional user ID to use. Will be auto-generated if missing.
:param device_id: An optional device ID to use. Will be auto-generated if missing.
"""
super().__init__()

self._anisette: BaseAnisetteProvider = anisette
self._uid: str = user_id or str(uuid.uuid4())
self._devid: str = device_id or str(uuid.uuid4())
Expand All @@ -251,6 +254,7 @@ def __init__(
self._account_info: _AccountInfo | None = None

self._http: HttpSession = HttpSession()
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)

def _set_login_state(
self,
Expand Down Expand Up @@ -411,20 +415,38 @@ async def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
# AUTHENTICATED -> LOGGED_IN
return await self._login_mobileme()

@require_login_state(LoginState.LOGGED_IN)
async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]:
"""Make a request for location reports, returning raw data."""
auth = (
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
)
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
r = await self._http.post(
"https://gateway.icloud.com/acsnservice/fetch",
auth=auth,
headers=await self.get_anisette_headers(),
json=data,
)
resp = r.json()
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise UnhandledProtocolError(msg)

return resp

@require_login_state(LoginState.LOGGED_IN)
async def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_reports`."""
anisette_headers = await self.get_anisette_headers()
date_to = date_to or datetime.now().astimezone()

return await fetch_reports(
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
anisette_headers,
return await self._reports.fetch_reports(
date_from,
date_to,
keys,
Expand All @@ -435,7 +457,7 @@ async def fetch_last_reports(
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours)
Expand Down Expand Up @@ -667,10 +689,11 @@ def __init__(
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

def __del__(self) -> None:
"""Gracefully close the async instance's session when garbage collected."""
coro = self._asyncacc.close()
return self._loop.run_until_complete(coro)
super().__init__(self._loop)

async def close(self) -> None:
"""See `AsyncAppleAccount.close`."""
await self._asyncacc.close()

@property
def login_state(self) -> LoginState:
Expand Down Expand Up @@ -737,8 +760,8 @@ def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._loop.run_until_complete(coro)
Expand All @@ -747,7 +770,7 @@ def fetch_last_reports(
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._loop.run_until_complete(coro)
Expand Down
17 changes: 10 additions & 7 deletions findmy/reports/anisette.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone

from findmy.util import HttpSession
from typing_extensions import override

from findmy.util.closable import Closable
from findmy.util.http import HttpSession


def _gen_meta_headers(
Expand All @@ -30,18 +33,13 @@ def _gen_meta_headers(
}


class BaseAnisetteProvider(ABC):
class BaseAnisetteProvider(Closable, ABC):
"""Abstract base class for Anisette providers."""

@abstractmethod
async def _get_base_headers(self) -> dict[str, str]:
raise NotImplementedError

@abstractmethod
async def close(self) -> None:
"""Close any underlying sessions. Call when the provider will no longer be used."""
raise NotImplementedError

async def get_headers(
self,
user_id: str,
Expand All @@ -64,6 +62,8 @@ class RemoteAnisetteProvider(BaseAnisetteProvider):

def __init__(self, server_url: str) -> None:
"""Initialize the provider with URL to te remote server."""
super().__init__()

self._server_url = server_url

self._http = HttpSession()
Expand All @@ -79,6 +79,7 @@ async def _get_base_headers(self) -> dict[str, str]:
"X-Apple-I-MD-M": headers["X-Apple-I-MD-M"],
}

@override
async def close(self) -> None:
"""See `AnisetteProvider.close`."""
await self._http.close()
Expand All @@ -91,9 +92,11 @@ class LocalAnisetteProvider(BaseAnisetteProvider):

def __init__(self) -> None:
"""Initialize the provider."""
super().__init__()

async def _get_base_headers(self) -> dict[str, str]:
return NotImplemented

@override
async def close(self) -> None:
"""See `AnisetteProvider.close`."""
Loading

0 comments on commit 7e3a1ed

Please # to comment.