diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 820d2c53..7a9938c2 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -213,7 +213,7 @@ async def shutdown(self) -> None: for handler in self.event_handlers[event]: await maybe_awaitable(handler(self.state)) - for middleware in self.middlewares: + for middleware in reversed(self.middlewares): if middleware.__class__.shutdown != TaskiqMiddleware.shutdown: await maybe_awaitable(middleware.shutdown()) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index beff1c82..9baff544 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -153,7 +153,7 @@ async def kiq( except Exception as exc: raise SendTaskError from exc - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: await maybe_awaitable(middleware.post_send(message)) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index ad078c39..f42f952d 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -157,7 +157,7 @@ async def callback( # noqa: C901, PLR0912 ): await maybe_awaitable(message.ack()) - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) @@ -165,7 +165,7 @@ async def callback( # noqa: C901, PLR0912 if not isinstance(result.error, NoResultError): await self.broker.result_backend.set_result(taskiq_msg.task_id, result) - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_save != TaskiqMiddleware.post_save: await maybe_awaitable(middleware.post_save(taskiq_msg, result)) @@ -307,7 +307,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 ) # If exception is found we execute middlewares. if found_exception is not None: - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.on_error != TaskiqMiddleware.on_error: await maybe_awaitable( middleware.on_error( diff --git a/tests/middlewares/test_hooks.py b/tests/middlewares/test_hooks.py new file mode 100644 index 00000000..fe9ffd1d --- /dev/null +++ b/tests/middlewares/test_hooks.py @@ -0,0 +1,156 @@ +import asyncio +from typing import Any + +import pytest + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.message import TaskiqMessage +from taskiq.result import TaskiqResult + + +@pytest.mark.anyio +async def test_set_broker() -> None: + + class _TestMiddleware(TaskiqMiddleware): + def set_broker(self, broker: "AsyncBroker") -> None: + super().set_broker(broker) + self.test_value = 1 + + middleware = _TestMiddleware() + broker = InMemoryBroker().with_middlewares(middleware) + + assert middleware is broker.middlewares[0] + assert middleware.test_value == 1 + + +@pytest.mark.anyio +async def test_startup_shutdown_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def startup(self) -> None: + test_list.append("1up") + + def shutdown(self) -> None: + test_list.append("1down") + + class _TestMiddleware2(TaskiqMiddleware): + async def startup(self) -> None: + await asyncio.sleep(0) + test_list.append("2up") + + async def shutdown(self) -> None: + await asyncio.sleep(0) + test_list.append("2down") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + await broker.shutdown() + + assert test_list == ["1up", "2up", "2down", "1down"] + + +@pytest.mark.anyio +async def test_pre_post_send_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def pre_send(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("1pre") + return message + + def post_send(self, message: "TaskiqMessage") -> None: + test_list.append("1post") + + class _TestMiddleware2(TaskiqMiddleware): + def pre_send(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("2pre") + return message + + def post_send(self, message: "TaskiqMessage") -> None: + test_list.append("2post") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + await broker.task(lambda: None).kiq() + await broker.shutdown() + + assert test_list == ["1pre", "2pre", "2post", "1post"] + + +@pytest.mark.anyio +async def test_pre_post_execute_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def pre_execute(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("1pre") + return message + + def post_execute(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("1post") + + class _TestMiddleware2(TaskiqMiddleware): + def pre_execute(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("2pre") + return message + + def post_execute(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("2post") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: 1).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["1pre", "2pre", "2post", "1post"] + + +@pytest.mark.anyio +async def test_post_save_inverted() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def post_save(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("1save") + + class _TestMiddleware2(TaskiqMiddleware): + def post_save(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("2save") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: 1).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["2save", "1save"] + + +@pytest.mark.anyio +async def test_post_on_error_inverted() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def on_error(self, message: "TaskiqMessage", result: "TaskiqResult[Any]", exception: BaseException) -> None: + test_list.append("1error") + + class _TestMiddleware2(TaskiqMiddleware): + def on_error(self, message: "TaskiqMessage", result: "TaskiqResult[Any]", exception: BaseException) -> None: + test_list.append("2error") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: (_ for _ in ()).throw(Exception("test"))).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["2error", "1error"]