diff --git a/taskiq/middlewares/__init__.py b/taskiq/middlewares/__init__.py index 18a9c50..2d36ae1 100644 --- a/taskiq/middlewares/__init__.py +++ b/taskiq/middlewares/__init__.py @@ -3,9 +3,11 @@ from .prometheus_middleware import PrometheusMiddleware from .simple_retry_middleware import SimpleRetryMiddleware from .smart_retry_middleware import SmartRetryMiddleware +from .taskiq_admin_middleware import TaskiqAdminMiddleware __all__ = ( "PrometheusMiddleware", "SimpleRetryMiddleware", "SmartRetryMiddleware", + "TaskiqAdminMiddleware", ) diff --git a/taskiq/middlewares/taskiq_admin_middleware.py b/taskiq/middlewares/taskiq_admin_middleware.py index 659e78c..7d1918d 100644 --- a/taskiq/middlewares/taskiq_admin_middleware.py +++ b/taskiq/middlewares/taskiq_admin_middleware.py @@ -6,7 +6,9 @@ import aiohttp -from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.message import TaskiqMessage +from taskiq.result import TaskiqResult __all__ = ("TaskiqAdminMiddleware",) @@ -118,6 +120,7 @@ async def post_send(self, message: TaskiqMessage) -> None: { "args": message.args, "kwargs": message.kwargs, + "labels": message.labels, "queuedAt": self._now_iso(), "taskName": message.task_name, "worker": self.__ta_broker_name, @@ -139,6 +142,7 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: { "args": message.args, "kwargs": message.kwargs, + "labels": message.labels, "startedAt": self._now_iso(), "taskName": message.task_name, "worker": self.__ta_broker_name, diff --git a/tests/middlewares/test_taskiq_admin_middleware.py b/tests/middlewares/test_taskiq_admin_middleware.py index bc7c0e3..68bbbac 100644 --- a/tests/middlewares/test_taskiq_admin_middleware.py +++ b/tests/middlewares/test_taskiq_admin_middleware.py @@ -1,4 +1,5 @@ import asyncio +import datetime from typing import AsyncGenerator from unittest.mock import AsyncMock, Mock, patch @@ -26,7 +27,19 @@ def message() -> TaskiqMessage: return TaskiqMessage( task_id="task-123", task_name="test_task", - labels={}, + labels={ + "schedule": { + "cron": "*/1 * * * *", + "cron_offset": datetime.timedelta(hours=1), + "time": datetime.datetime.now(datetime.timezone.utc), + "labels": { + "test_bool": True, + "test_int": 1, + "test_str": "str", + "test_bytes": b"bytes", + }, + }, + }, args=[1, 2, 3], kwargs={"key": "value"}, ) @@ -80,8 +93,9 @@ async def test_when_post_send_is_called__then_payload_includes_task_info( call_args = mock_post.call_args assert call_args is not None payload = call_args[1]["json"] - assert payload["args"] == [1, 2, 3] - assert payload["kwargs"] == {"key": "value"} - assert payload["taskName"] == "test_task" + assert payload["args"] == message.args + assert payload["kwargs"] == message.kwargs + assert payload["taskName"] == message.task_name assert payload["worker"] == "test-broker" + assert payload["labels"] == message.labels assert "queuedAt" in payload