From 5a636b4f174be7b6756f5de072ac258b578b1c05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20B=C3=B6hm?= Date: Thu, 1 Aug 2024 20:11:02 +0200 Subject: [PATCH] refactor: Replace messages generator with iterator class that implements len() --- aiomqtt/client.py | 70 ++++++++++++++++++++++++-------------------- tests/test_client.py | 4 ++- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/aiomqtt/client.py b/aiomqtt/client.py index 4b4439e..cf551d4 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -14,7 +14,7 @@ from types import TracebackType from typing import ( Any, - AsyncGenerator, + AsyncIterator, Awaitable, Callable, Coroutine, @@ -125,7 +125,7 @@ class Will: class Client: - """The async context manager that manages the connection to the broker. + """Asynchronous context manager for the connection to the MQTT broker. Args: hostname: The hostname or IP address of the remote broker. @@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 timeout = 10 self.timeout = timeout - @property - def messages(self) -> AsyncGenerator[Message, None]: - return self._messages() - @property def identifier(self) -> str: """Return the client identifier. @@ -333,6 +329,42 @@ def identifier(self) -> str: """ return self._client._client_id.decode() # noqa: SLF001 + class MessagesIterator: + """Dynamic view of the message queue.""" + + def __init__(self, client: Client) -> None: + self._client = client + + def __aiter__(self) -> AsyncIterator[Message]: + return self + + async def __anext__(self) -> Message: + # Wait until we either (1) receive a message or (2) disconnect + task = self._client._loop.create_task(self._client._queue.get()) # noqa: SLF001 + try: + done, _ = await asyncio.wait( + (task, self._client._disconnected), # noqa: SLF001 + return_when=asyncio.FIRST_COMPLETED, + ) + # If the asyncio.wait is cancelled, we must also cancel the queue task + except asyncio.CancelledError: + task.cancel() + raise + # When we receive a message, return it + if task in done: + return task.result() + # If we disconnect from the broker, stop the generator with an exception + task.cancel() + msg = "Disconnected during message iteration" + raise MqttError(msg) + + def __len__(self) -> int: + return self._client._queue.qsize() # noqa: SLF001 + + @property + def messages(self) -> MessagesIterator: + return self.MessagesIterator(self) + @property def _pending_calls(self) -> Generator[int, None, None]: """Yield all message IDs with pending calls.""" @@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913 # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) - async def _messages(self) -> AsyncGenerator[Message, None]: - """Async generator that yields messages from the underlying message queue.""" - while True: - # Wait until we either: - # 1. Receive a message - # 2. Disconnect from the broker - task = self._loop.create_task(self._queue.get()) - try: - done, _ = await asyncio.wait( - (task, self._disconnected), return_when=asyncio.FIRST_COMPLETED - ) - except asyncio.CancelledError: - # If the asyncio.wait is cancelled, we must make sure - # to also cancel the underlying tasks. - task.cancel() - raise - if task in done: - # We received a message. Return the result. - yield task.result() - else: - # We were disconnected from the broker - task.cancel() - # Stop the generator with an exception - msg = "Disconnected during message iteration" - raise MqttError(msg) - async def _wait_for( self, fut: Awaitable[T], timeout: float | None, **kwargs: Any ) -> T: diff --git a/tests/test_client.py b/tests/test_client.py index 6ffc877..1a3b544 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio import logging import pathlib import ssl import sys +from typing import Any import anyio import anyio.abc @@ -413,7 +415,7 @@ async def test_messages_view_is_reusable() -> None: @pytest.mark.network async def test_messages_view_multiple_tasks_concurrently() -> None: """Test that ``.messages`` can be used concurrently by multiple tasks.""" - topic = TOPIC_PREFIX + "test_messages_generator_is_reentrant" + topic = TOPIC_PREFIX + "test_messages_view_multiple_tasks_concurrently" async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: async def handle() -> None: