Skip to content

Commit 97ed039

Browse files
committed
feat: make methods async, add startup/shutdown
1 parent ee998a6 commit 97ed039

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import UTC, datetime
33
from logging import getLogger
44
from types import CoroutineType
5-
from typing import Any, Coroutine, Self, Union
5+
from typing import Any, Coroutine, Union
66
from urllib.parse import urljoin
77

88
import aiohttp
@@ -51,7 +51,7 @@ def __init__(
5151
def _now_iso() -> str:
5252
return datetime.now(UTC).replace(tzinfo=None).isoformat()
5353

54-
def _get_session(self: Self) -> aiohttp.ClientSession:
54+
def _get_client(self) -> aiohttp.ClientSession:
5555
"""Create and cache session."""
5656
if self._client is None or self._client.closed:
5757
self._client = aiohttp.ClientSession(
@@ -60,8 +60,26 @@ def _get_session(self: Self) -> aiohttp.ClientSession:
6060

6161
return self._client
6262

63-
def _spawn_request(
64-
self: Self,
63+
async def startup(self) -> None:
64+
"""
65+
Startup method to initialize aiohttp.ClientSession.
66+
67+
:returns nothing.
68+
"""
69+
self._client = self._get_client()
70+
71+
async def shutdown(self) -> None:
72+
"""Shutdown method to run all pending requests and close the session.
73+
74+
:returns nothing.
75+
"""
76+
if self._pending:
77+
await asyncio.gather(*self._pending, return_exceptions=True)
78+
if self._client is not None:
79+
await self._client.close()
80+
81+
async def _spawn_request(
82+
self,
6583
endpoint: str,
6684
payload: dict[str, Any],
6785
) -> None:
@@ -72,9 +90,9 @@ def _spawn_request(
7290
"""
7391

7492
async def _send() -> None:
75-
session = self._get_session()
93+
client = self._get_client()
7694

77-
async with session.post(
95+
async with client.post(
7896
urljoin(self.url, endpoint),
7997
headers={"access-token": self.api_token},
8098
json=payload,
@@ -87,8 +105,8 @@ async def _send() -> None:
87105
self._pending.add(task)
88106
task.add_done_callback(self._pending.discard)
89107

90-
def post_send(
91-
self: Self,
108+
async def post_send(
109+
self,
92110
message: TaskiqMessage,
93111
) -> Union[None, Coroutine[Any, Any, None], "CoroutineType[Any, Any, None]"]:
94112
"""
@@ -99,7 +117,7 @@ def post_send(
99117
100118
:param message: kicked message.
101119
"""
102-
self._spawn_request(
120+
await self._spawn_request(
103121
f"/api/tasks/{message.task_id}/queued",
104122
{
105123
"args": message.args,
@@ -111,7 +129,7 @@ def post_send(
111129
)
112130
return super().post_send(message)
113131

114-
def pre_execute(
132+
async def pre_execute(
115133
self,
116134
message: TaskiqMessage,
117135
) -> Union[
@@ -128,7 +146,7 @@ def pre_execute(
128146
:param message: incoming parsed taskiq message.
129147
:return: modified message.
130148
"""
131-
self._spawn_request(
149+
await self._spawn_request(
132150
f"/api/tasks/{message.task_id}/started",
133151
{
134152
"args": message.args,
@@ -140,7 +158,7 @@ def pre_execute(
140158
)
141159
return super().pre_execute(message)
142160

143-
def post_execute(
161+
async def post_execute(
144162
self,
145163
message: TaskiqMessage,
146164
result: TaskiqResult[Any],
@@ -154,7 +172,7 @@ def post_execute(
154172
:param message: incoming message.
155173
:param result: result of execution for current task.
156174
"""
157-
self._spawn_request(
175+
await self._spawn_request(
158176
f"/api/tasks/{message.task_id}/executed",
159177
{
160178
"finishedAt": self._now_iso(),

0 commit comments

Comments
 (0)