diff --git a/docs/examples/extending/broker.py b/docs/examples/extending/broker.py index 026210d2..628a0ee6 100644 --- a/docs/examples/extending/broker.py +++ b/docs/examples/extending/broker.py @@ -1,6 +1,6 @@ from typing import AsyncGenerator, Union -from taskiq import AckableMessage, AsyncBroker, BrokerMessage +from taskiq import WrappedMessage, AsyncBroker, BrokerMessage class MyBroker(AsyncBroker): @@ -23,7 +23,7 @@ async def kick(self, message: BrokerMessage) -> None: # Send a message.message. pass - async def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]: + async def listen(self) -> AsyncGenerator[Union[bytes, WrappedMessage], None]: while True: # Get new message. new_message: bytes = ... # type: ignore diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 87a5c4f1..fbe1c2a9 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -8,7 +8,6 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend from taskiq.abc.schedule_source import ScheduleSource -from taskiq.acks import AckableMessage from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.brokers.shared_broker import async_shared_broker from taskiq.brokers.zmq_broker import ZeroMQBroker @@ -24,7 +23,7 @@ TaskiqResultTimeoutError, ) from taskiq.funcs import gather -from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.message import BrokerMessage, MessageMetadata, TaskiqMessage, WrappedMessage from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware from taskiq.result import TaskiqResult @@ -35,7 +34,6 @@ __version__ = version("taskiq") __all__ = [ - "AckableMessage", "AsyncBroker", "AsyncResultBackend", "AsyncTaskiqDecoratedTask", @@ -43,6 +41,7 @@ "BrokerMessage", "Context", "InMemoryBroker", + "MessageMetadata", "NoResultError", "PrometheusMiddleware", "ResultGetError", @@ -62,6 +61,7 @@ "TaskiqResultTimeoutError", "TaskiqScheduler", "TaskiqState", + "WrappedMessage", "ZeroMQBroker", "__version__", "async_shared_broker", diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 69c9e25c..8df9b43c 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -26,11 +26,10 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.serializer import TaskiqSerializer -from taskiq.acks import AckableMessage from taskiq.decor import AsyncTaskiqDecoratedTask from taskiq.events import TaskiqEvents from taskiq.formatters.proxy_formatter import ProxyFormatter -from taskiq.message import BrokerMessage +from taskiq.message import BrokerMessage, WrappedMessage from taskiq.result_backends.dummy import DummyResultBackend from taskiq.serializers.json_serializer import JSONSerializer from taskiq.state import TaskiqState @@ -77,6 +76,7 @@ def __init__( self, result_backend: "Optional[AsyncResultBackend[_T]]" = None, task_id_generator: Optional[Callable[[], str]] = None, + max_attempts_at_message: Optional[int] = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -113,6 +113,7 @@ def __init__( self.state = TaskiqState() self.custom_dependency_context: Dict[Any, Any] = {} self.dependency_overrides: Dict[Any, Any] = {} + self.max_attempts_at_message = max_attempts_at_message # True only if broker runs in worker process. self.is_worker_process = False # True only if broker runs in scheduler process. @@ -237,18 +238,20 @@ async def kick( """ @abstractmethod - def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]: + def listen(self) -> AsyncGenerator[Union[bytes, WrappedMessage], None]: """ This function listens to new messages and yields them. This it the main point for workers. This function is used to get new tasks from the network. - If your broker support acknowledgement, then you - should wrap your message in AckableMessage dataclass. + If your broker support acknowledgements (or negative acknowledgements), + then the returned message should implement the AckableMessage + (or NackableMessage) interface by implementing the `ack` (or + `nack`) callback. - If your messages was wrapped in AckableMessage dataclass, - taskiq will call ack when finish processing message. + If your message has an `ack` callbacks it will be called after the + message is processed. :yield: incoming messages. :return: nothing. diff --git a/taskiq/acks.py b/taskiq/acks.py index c3b3fe77..3722319d 100644 --- a/taskiq/acks.py +++ b/taskiq/acks.py @@ -33,5 +33,15 @@ class AckableMessage(BaseModel): as a whole. """ - data: bytes ack: Callable[[], Union[None, Awaitable[None]]] + + +class NackableMessage(BaseModel): + """ + Message that can be negatively acknowledged. + + Message that can be negatively acknowledged, e.g. + sent to a dead-letter queue, etc. + """ + + nack: Callable[[], Union[None, Awaitable[None]]] diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 4977eb39..ee3ffff3 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -148,6 +148,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: ack_type=args.ack_type, max_tasks_to_execute=args.max_tasks_per_child, wait_tasks_timeout=args.wait_tasks_timeout, + max_attempts_at_message=broker.max_attempts_at_message, **receiver_kwargs, # type: ignore ) loop.run_until_complete(receiver.listen(shutdown_event)) diff --git a/taskiq/message.py b/taskiq/message.py index 675f7cf3..50ea8a03 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from taskiq.acks import AckableMessage, NackableMessage from taskiq.labels import parse_label @@ -42,3 +43,42 @@ class BrokerMessage(BaseModel): task_name: str message: bytes labels: Dict[str, Any] + + +class MessageMetadata(BaseModel): + """Incoming message metadata.""" + + delivery_count: Optional[int] = None + + +class WrappedMessage(BaseModel): # noqa: D101 + message: bytes + + +class MessageWithMetadata(BaseModel): # noqa: D101 + metadata: MessageMetadata + + +class WrappedMessageWithMetadata(WrappedMessage, MessageWithMetadata): # noqa: D101 + ... + + +class AckableWrappedMessage(WrappedMessage, AckableMessage): # noqa: D101 + ... + + +class AckableWrappedMessageWithMetadata( # noqa: D101 + WrappedMessage, + AckableMessage, + MessageWithMetadata, +): + ... + + +class AckableNackableWrappedMessageWithMetadata( # noqa: D101 + WrappedMessage, + AckableMessage, + NackableMessage, + MessageWithMetadata, +): + ... diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 41e687a6..b0b49b1d 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -8,12 +8,12 @@ import anyio from taskiq_dependencies import DependencyGraph -from taskiq.abc.broker import AckableMessage, AsyncBroker +from taskiq.abc.broker import AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware -from taskiq.acks import AcknowledgeType +from taskiq.acks import AckableMessage, AcknowledgeType, NackableMessage from taskiq.context import Context from taskiq.exceptions import NoResultError -from taskiq.message import TaskiqMessage +from taskiq.message import MessageWithMetadata, TaskiqMessage, WrappedMessage from taskiq.receiver.params_parser import parse_params from taskiq.result import TaskiqResult from taskiq.state import TaskiqState @@ -58,6 +58,7 @@ def __init__( on_exit: Optional[Callable[["Receiver"], None]] = None, max_tasks_to_execute: Optional[int] = None, wait_tasks_timeout: Optional[float] = None, + max_attempts_at_message: Optional[int] = None, ) -> None: self.broker = broker self.executor = executor @@ -72,6 +73,7 @@ def __init__( self.known_tasks: Set[str] = set() self.max_tasks_to_execute = max_tasks_to_execute self.wait_tasks_timeout = wait_tasks_timeout + self.max_attempts_at_message = max_attempts_at_message for task in self.broker.get_all_tasks().values(): self._prepare_task(task.task_name, task.original_func) self.sem: "Optional[asyncio.Semaphore]" = None @@ -86,7 +88,7 @@ def __init__( async def callback( # noqa: C901, PLR0912 self, - message: Union[bytes, AckableMessage], + message: Union[bytes, WrappedMessage], raise_err: bool = False, ) -> None: """ @@ -101,7 +103,33 @@ async def callback( # noqa: C901, PLR0912 :param raise_err: raise an error if cannot save result in result_backend. """ - message_data = message.data if isinstance(message, AckableMessage) else message + message_data = ( + message.message if isinstance(message, WrappedMessage) else message + ) + if isinstance(message, MessageWithMetadata): + message_metadata = message.metadata + else: + message_metadata = None + + delivery_count = message_metadata.delivery_count if message_metadata else None + if ( + delivery_count + and self.max_attempts_at_message + and delivery_count >= self.max_attempts_at_message + ): + logger.error( + "Permitted number of attempts at processing message %s " + "has been exhausted after %s attempts.", + message_data, + self.max_attempts_at_message, + ) + match message: + case NackableMessage(): + await maybe_awaitable(message.nack()) + case AckableMessage(): + await maybe_awaitable(message.ack()) + return + try: taskiq_msg = self.broker.formatter.loads(message=message_data) taskiq_msg.parse_labels() @@ -331,7 +359,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover if self.run_startup: await self.broker.startup() logger.info("Listening started.") - queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue() + queue: "asyncio.Queue[Union[bytes, WrappedMessage]]" = asyncio.Queue() async with anyio.create_task_group() as gr: gr.start_soon(self.prefetcher, queue, finish_event) @@ -342,7 +370,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover async def prefetcher( self, - queue: "asyncio.Queue[Union[bytes, AckableMessage]]", + queue: "asyncio.Queue[Union[bytes, WrappedMessage]]", finish_event: asyncio.Event, ) -> None: """ @@ -354,7 +382,7 @@ async def prefetcher( fetched_tasks: int = 0 iterator = self.broker.listen() current_message: asyncio.Task[ - Union[bytes, AckableMessage] + Union[bytes, WrappedMessage] ] = asyncio.create_task( iterator.__anext__(), # type: ignore ) @@ -394,7 +422,7 @@ async def prefetcher( async def runner( self, - queue: "asyncio.Queue[Union[bytes, AckableMessage]]", + queue: "asyncio.Queue[Union[bytes, WrappedMessage]]", ) -> None: """ Run tasks. diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 5a65efff..a8b2e79f 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -7,11 +7,18 @@ import pytest from taskiq_dependencies import Depends -from taskiq.abc.broker import AckableMessage, AsyncBroker +from taskiq.abc.broker import AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError -from taskiq.message import TaskiqMessage +from taskiq.message import ( + AckableNackableWrappedMessageWithMetadata, + AckableWrappedMessage, + AckableWrappedMessageWithMetadata, + MessageMetadata, + TaskiqMessage, + WrappedMessageWithMetadata, +) from taskiq.receiver import Receiver from taskiq.result import TaskiqResult from tests.utils import AsyncQueueBroker @@ -282,8 +289,8 @@ def ack_callback() -> None: ) await receiver.callback( - AckableMessage( - data=broker_message.message, + AckableWrappedMessage( + message=broker_message.message, ack=ack_callback, ), ) @@ -321,8 +328,8 @@ async def ack_callback() -> None: ) await receiver.callback( - AckableMessage( - data=broker_message.message, + AckableWrappedMessage( + message=broker_message.message, ack=ack_callback, ), ) @@ -359,6 +366,170 @@ async def test_callback_unknown_task() -> None: await receiver.callback(broker_message.message) +@pytest.mark.anyio +@pytest.mark.parametrize("delivery_count", [2, None]) +async def test_callback_max_attempts_at_message_not_exceeded( + delivery_count: Optional[int], +) -> None: + broker = InMemoryBroker() + called_times = 0 + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + WrappedMessageWithMetadata( + message=broker_message.message, + metadata=MessageMetadata( + delivery_count=delivery_count, + ), + ), + ) + assert called_times == 1 + + +@pytest.mark.anyio +async def test_callback_max_attempts_at_message_exceeded() -> None: + broker = InMemoryBroker() + called_times = 0 + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + WrappedMessageWithMetadata( + message=broker_message.message, + metadata=MessageMetadata( + delivery_count=3, + ), + ), + ) + assert called_times == 0 + + +@pytest.mark.anyio +async def test_callback_max_attempts_at_message_exceeded_ackable() -> None: + broker = InMemoryBroker() + called_times = 0 + acked = False + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + async def ack_callback() -> None: + nonlocal acked + acked = True + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + AckableWrappedMessageWithMetadata( + message=broker_message.message, + metadata=MessageMetadata( + delivery_count=3, + ), + ack=ack_callback, + ), + ) + assert called_times == 0 + assert acked + + +@pytest.mark.anyio +async def test_callback_max_attempts_at_message_exceeded_nackable() -> None: + broker = InMemoryBroker() + called_times = 0 + acked = False + nacked = False + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + async def ack_callback() -> None: + nonlocal acked + acked = True + + async def nack_callback() -> None: + nonlocal nacked + nacked = True + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + AckableNackableWrappedMessageWithMetadata( + message=broker_message.message, + metadata=MessageMetadata( + delivery_count=3, + ), + ack=ack_callback, + nack=nack_callback, + ), + ) + assert called_times == 0 + assert not acked + assert nacked + + @pytest.mark.anyio async def test_custom_ctx() -> None: """Tests that run_task can run sync tasks.""" diff --git a/tests/utils.py b/tests/utils.py index 67603808..c366fe29 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,7 @@ from typing import AsyncGenerator from taskiq import AsyncBroker, BrokerMessage -from taskiq.acks import AckableMessage +from taskiq.message import AckableWrappedMessage class AsyncQueueBroker(AsyncBroker): @@ -25,8 +25,8 @@ async def wait_tasks(self) -> None: """Small method to wait for all tasks to be processed.""" await self.queue.join() - async def listen(self) -> AsyncGenerator[AckableMessage, None]: + async def listen(self) -> AsyncGenerator[AckableWrappedMessage, None]: """This method returns all tasks from queue.""" while True: task = await self.queue.get() - yield AckableMessage(data=task, ack=self.queue.task_done) + yield AckableWrappedMessage(message=task, ack=self.queue.task_done)