From d21b3d17ed0a8677ad03484f15e3cad1a553f679 Mon Sep 17 00:00:00 2001 From: hawang-wish <130547790+hawang-wish@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:15:55 +0800 Subject: [PATCH 1/2] Invert execution orders of post, on_error and shutdown middleware hooks --- taskiq/abc/broker.py | 2 +- taskiq/kicker.py | 2 +- taskiq/receiver/receiver.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index b28e811e..81adb306 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -213,7 +213,7 @@ async def shutdown(self) -> None: for handler in self.event_handlers[event]: await maybe_awaitable(handler(self.state)) - for middleware in self.middlewares: + for middleware in reversed(self.middlewares): if middleware.__class__.shutdown != TaskiqMiddleware.shutdown: await maybe_awaitable(middleware.shutdown) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 889dee00..bcd3f7e3 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -139,7 +139,7 @@ async def kiq( except Exception as exc: raise SendTaskError from exc - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: await maybe_awaitable(middleware.post_send(message)) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 7d5a4035..a4d1ac9f 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -156,7 +156,7 @@ async def callback( # noqa: C901, PLR0912 ): await maybe_awaitable(message.ack()) - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) @@ -164,7 +164,7 @@ async def callback( # noqa: C901, PLR0912 if not isinstance(result.error, NoResultError): await self.broker.result_backend.set_result(taskiq_msg.task_id, result) - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.post_save != TaskiqMiddleware.post_save: await maybe_awaitable(middleware.post_save(taskiq_msg, result)) @@ -306,7 +306,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 ) # If exception is found we execute middlewares. if found_exception is not None: - for middleware in self.broker.middlewares: + for middleware in reversed(self.broker.middlewares): if middleware.__class__.on_error != TaskiqMiddleware.on_error: await maybe_awaitable( middleware.on_error( From 5c59fb4714143ef0e1b2f61e2012a20636461e0f Mon Sep 17 00:00:00 2001 From: hawang-wish <130547790+hawang-wish@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:27:57 +0800 Subject: [PATCH 2/2] Add tests --- tests/middlewares/test_hooks.py | 156 ++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/middlewares/test_hooks.py diff --git a/tests/middlewares/test_hooks.py b/tests/middlewares/test_hooks.py new file mode 100644 index 00000000..fe9ffd1d --- /dev/null +++ b/tests/middlewares/test_hooks.py @@ -0,0 +1,156 @@ +import asyncio +from typing import Any + +import pytest + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.message import TaskiqMessage +from taskiq.result import TaskiqResult + + +@pytest.mark.anyio +async def test_set_broker() -> None: + + class _TestMiddleware(TaskiqMiddleware): + def set_broker(self, broker: "AsyncBroker") -> None: + super().set_broker(broker) + self.test_value = 1 + + middleware = _TestMiddleware() + broker = InMemoryBroker().with_middlewares(middleware) + + assert middleware is broker.middlewares[0] + assert middleware.test_value == 1 + + +@pytest.mark.anyio +async def test_startup_shutdown_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def startup(self) -> None: + test_list.append("1up") + + def shutdown(self) -> None: + test_list.append("1down") + + class _TestMiddleware2(TaskiqMiddleware): + async def startup(self) -> None: + await asyncio.sleep(0) + test_list.append("2up") + + async def shutdown(self) -> None: + await asyncio.sleep(0) + test_list.append("2down") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + await broker.shutdown() + + assert test_list == ["1up", "2up", "2down", "1down"] + + +@pytest.mark.anyio +async def test_pre_post_send_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def pre_send(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("1pre") + return message + + def post_send(self, message: "TaskiqMessage") -> None: + test_list.append("1post") + + class _TestMiddleware2(TaskiqMiddleware): + def pre_send(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("2pre") + return message + + def post_send(self, message: "TaskiqMessage") -> None: + test_list.append("2post") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + await broker.task(lambda: None).kiq() + await broker.shutdown() + + assert test_list == ["1pre", "2pre", "2post", "1post"] + + +@pytest.mark.anyio +async def test_pre_post_execute_in_pair() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def pre_execute(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("1pre") + return message + + def post_execute(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("1post") + + class _TestMiddleware2(TaskiqMiddleware): + def pre_execute(self, message: "TaskiqMessage") -> "TaskiqMessage": + test_list.append("2pre") + return message + + def post_execute(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("2post") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: 1).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["1pre", "2pre", "2post", "1post"] + + +@pytest.mark.anyio +async def test_post_save_inverted() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def post_save(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("1save") + + class _TestMiddleware2(TaskiqMiddleware): + def post_save(self, message: "TaskiqMessage", result: "TaskiqResult[Any]") -> None: + test_list.append("2save") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: 1).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["2save", "1save"] + + +@pytest.mark.anyio +async def test_post_on_error_inverted() -> None: + test_list = [] + + class _TestMiddleware1(TaskiqMiddleware): + def on_error(self, message: "TaskiqMessage", result: "TaskiqResult[Any]", exception: BaseException) -> None: + test_list.append("1error") + + class _TestMiddleware2(TaskiqMiddleware): + def on_error(self, message: "TaskiqMessage", result: "TaskiqResult[Any]", exception: BaseException) -> None: + test_list.append("2error") + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + await broker.startup() + task = await broker.task(lambda: (_ for _ in ()).throw(Exception("test"))).kiq() + await task.wait_result(timeout=2) + await broker.shutdown() + + assert test_list == ["2error", "1error"]