Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions taskiq/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from .prometheus_middleware import PrometheusMiddleware
from .simple_retry_middleware import SimpleRetryMiddleware
from .smart_retry_middleware import SmartRetryMiddleware
from .taskiq_admin_middleware import TaskiqAdminMiddleware

__all__ = (
"PrometheusMiddleware",
"SimpleRetryMiddleware",
"SmartRetryMiddleware",
"TaskiqAdminMiddleware",
)
6 changes: 5 additions & 1 deletion taskiq/middlewares/taskiq_admin_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import aiohttp

from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult

__all__ = ("TaskiqAdminMiddleware",)

Expand Down Expand Up @@ -118,6 +120,7 @@ async def post_send(self, message: TaskiqMessage) -> None:
{
"args": message.args,
"kwargs": message.kwargs,
"labels": message.labels,
"queuedAt": self._now_iso(),
"taskName": message.task_name,
"worker": self.__ta_broker_name,
Expand All @@ -139,6 +142,7 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
{
"args": message.args,
"kwargs": message.kwargs,
"labels": message.labels,
"startedAt": self._now_iso(),
"taskName": message.task_name,
"worker": self.__ta_broker_name,
Expand Down
22 changes: 18 additions & 4 deletions tests/middlewares/test_taskiq_admin_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import datetime
from typing import AsyncGenerator
from unittest.mock import AsyncMock, Mock, patch

Expand Down Expand Up @@ -26,7 +27,19 @@ def message() -> TaskiqMessage:
return TaskiqMessage(
task_id="task-123",
task_name="test_task",
labels={},
labels={
"schedule": {
"cron": "*/1 * * * *",
"cron_offset": datetime.timedelta(hours=1),
"time": datetime.datetime.now(datetime.timezone.utc),
"labels": {
"test_bool": True,
"test_int": 1,
"test_str": "str",
"test_bytes": b"bytes",
},
},
},
args=[1, 2, 3],
kwargs={"key": "value"},
)
Expand Down Expand Up @@ -80,8 +93,9 @@ async def test_when_post_send_is_called__then_payload_includes_task_info(
call_args = mock_post.call_args
assert call_args is not None
payload = call_args[1]["json"]
assert payload["args"] == [1, 2, 3]
assert payload["kwargs"] == {"key": "value"}
assert payload["taskName"] == "test_task"
assert payload["args"] == message.args
assert payload["kwargs"] == message.kwargs
assert payload["taskName"] == message.task_name
assert payload["worker"] == "test-broker"
assert payload["labels"] == message.labels
assert "queuedAt" in payload