Skip to content

Commit 35ae965

Browse files
committed
fix: typing/linter errors, update middleware
1 parent cc85f6a commit 35ae965

File tree

1 file changed

+65
-18
lines changed

1 file changed

+65
-18
lines changed

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
from datetime import UTC, datetime
33
from logging import getLogger
4-
from typing import Any
4+
from types import CoroutineType
5+
from typing import Any, Coroutine, Self, Union
56
from urllib.parse import urljoin
67

78
import aiohttp
@@ -37,7 +38,7 @@ def __init__(
3738
api_token: str,
3839
timeout: int = 5,
3940
taskiq_broker_name: str | None = None,
40-
):
41+
) -> None:
4142
super().__init__()
4243
self.url = url
4344
self.timeout = timeout
@@ -50,22 +51,28 @@ def __init__(
5051
def _now_iso() -> str:
5152
return datetime.now(UTC).replace(tzinfo=None).isoformat()
5253

53-
async def startup(self):
54-
self._client = aiohttp.ClientSession(
55-
timeout=aiohttp.ClientTimeout(total=self.timeout),
56-
)
54+
def _get_session(self: Self) -> aiohttp.ClientSession:
55+
"""Create and cache session."""
56+
if self._client is None or self._client.closed:
57+
self._client = aiohttp.ClientSession(
58+
timeout=aiohttp.ClientTimeout(total=self.timeout),
59+
)
60+
61+
return self._client
5762

58-
async def shutdown(self):
59-
if self._pending:
60-
await asyncio.gather(*self._pending, return_exceptions=True)
61-
if self._client is not None:
62-
await self._client.close()
63+
def _spawn_request(
64+
self: Self,
65+
endpoint: str,
66+
payload: dict[str, Any],
67+
) -> None:
68+
"""Fire and forget helper.
69+
70+
start an async POST to the admin API, keep the resulting Task in _pending
71+
so it can be awaited/cleaned during graceful shutdown.
72+
"""
6373

64-
def _spawn_request(self, endpoint: str, payload: dict[str, Any]) -> None:
6574
async def _send() -> None:
66-
session = self._client or aiohttp.ClientSession(
67-
timeout=aiohttp.ClientTimeout(total=self.timeout)
68-
)
75+
session = self._get_session()
6976

7077
async with session.post(
7178
urljoin(self.url, endpoint),
@@ -80,7 +87,18 @@ async def _send() -> None:
8087
self._pending.add(task)
8188
task.add_done_callback(self._pending.discard)
8289

83-
async def post_send(self, message):
90+
def post_send(
91+
self: Self,
92+
message: TaskiqMessage,
93+
) -> Union[None, Coroutine[Any, Any, None], "CoroutineType[Any, Any, None]"]:
94+
"""
95+
This hook is executed right after the task is sent.
96+
97+
This is a client-side hook. It executes right
98+
after the messages is kicked in broker.
99+
100+
:param message: kicked message.
101+
"""
84102
self._spawn_request(
85103
f"/api/tasks/{message.task_id}/queued",
86104
{
@@ -93,7 +111,23 @@ async def post_send(self, message):
93111
)
94112
return super().post_send(message)
95113

96-
async def pre_execute(self, message: TaskiqMessage):
114+
def pre_execute(
115+
self,
116+
message: TaskiqMessage,
117+
) -> Union[
118+
"TaskiqMessage",
119+
"Coroutine[Any, Any, TaskiqMessage]",
120+
"CoroutineType[Any, Any, TaskiqMessage]",
121+
]:
122+
"""
123+
This hook is called before executing task.
124+
125+
This is a worker-side hook, which means it
126+
executes in the worker process.
127+
128+
:param message: incoming parsed taskiq message.
129+
:return: modified message.
130+
"""
97131
self._spawn_request(
98132
f"/api/tasks/{message.task_id}/started",
99133
{
@@ -106,7 +140,20 @@ async def pre_execute(self, message: TaskiqMessage):
106140
)
107141
return super().pre_execute(message)
108142

109-
async def post_execute(self, message: TaskiqMessage, result: TaskiqResult[Any]):
143+
def post_execute(
144+
self,
145+
message: TaskiqMessage,
146+
result: TaskiqResult[Any],
147+
) -> Union[None, Coroutine[Any, Any, None], "CoroutineType[Any, Any, None]"]:
148+
"""
149+
This hook executes after task is complete.
150+
151+
This is a worker-side hook. It's called
152+
in worker process.
153+
154+
:param message: incoming message.
155+
:param result: result of execution for current task.
156+
"""
110157
self._spawn_request(
111158
f"/api/tasks/{message.task_id}/executed",
112159
{

0 commit comments

Comments
 (0)