diff --git a/pydrawise/hybrid.py b/pydrawise/hybrid.py index 9fc9fd1..126d5f0 100644 --- a/pydrawise/hybrid.py +++ b/pydrawise/hybrid.py @@ -77,6 +77,8 @@ def __init__( auth: HybridAuth, app_id: str = DEFAULT_APP_ID, gql_client: Hydrawise | None = None, + gql_throttle: Throttler | None = None, + rest_throttle: Throttler | None = None, ) -> None: if gql_client is None: gql_client = Hydrawise(auth, app_id) @@ -86,12 +88,16 @@ def __init__( self._user: User | None = None self._controllers: dict[int, Controller] = {} self._zones: dict[int, Zone] = {} - self._gql_throttle = Throttler( - epoch_interval=timedelta(minutes=30), tokens_per_epoch=2 - ) - self._rest_throttle = Throttler( - epoch_interval=timedelta(minutes=1), tokens_per_epoch=2 - ) + if gql_throttle is None: + gql_throttle = Throttler( + epoch_interval=timedelta(minutes=30), tokens_per_epoch=5 + ) + self._gql_throttle: Throttler = gql_throttle + if rest_throttle is None: + rest_throttle = Throttler( + epoch_interval=timedelta(minutes=1), tokens_per_epoch=2 + ) + self._rest_throttle: Throttler = rest_throttle async def get_user(self, fetch_zones: bool = True) -> User: async with self._lock: diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 580b067..c5faf05 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -5,9 +5,9 @@ from freezegun import freeze_time from pytest import fixture -from pydrawise import hybrid from pydrawise.auth import HybridAuth from pydrawise.client import Hydrawise +from pydrawise.hybrid import HybridClient, Throttler FROZEN_TIME = "2023-01-01 01:00:00" @@ -26,12 +26,21 @@ def mock_gql_client(): @fixture def api(hybrid_auth, mock_gql_client): - yield hybrid.HybridClient(hybrid_auth, gql_client=mock_gql_client) + yield HybridClient( + hybrid_auth, + gql_client=mock_gql_client, + gql_throttle=Throttler( + epoch_interval=timedelta(minutes=30), tokens_per_epoch=2 + ), + rest_throttle=Throttler( + epoch_interval=timedelta(minutes=1), tokens_per_epoch=2 + ), + ) def test_throttler(): with freeze_time(FROZEN_TIME) as frozen_time: - throttle = hybrid.Throttler(epoch_interval=timedelta(seconds=60)) + throttle = Throttler(epoch_interval=timedelta(seconds=60)) assert throttle.check() throttle.mark() assert not throttle.check()