2
2
from datetime import UTC , datetime
3
3
from logging import getLogger
4
4
from types import CoroutineType
5
- from typing import Any , Coroutine , Self , Union
5
+ from typing import Any , Coroutine , Union
6
6
from urllib .parse import urljoin
7
7
8
8
import aiohttp
@@ -51,7 +51,7 @@ def __init__(
51
51
def _now_iso () -> str :
52
52
return datetime .now (UTC ).replace (tzinfo = None ).isoformat ()
53
53
54
- def _get_session (self : Self ) -> aiohttp .ClientSession :
54
+ def _get_client (self ) -> aiohttp .ClientSession :
55
55
"""Create and cache session."""
56
56
if self ._client is None or self ._client .closed :
57
57
self ._client = aiohttp .ClientSession (
@@ -60,8 +60,26 @@ def _get_session(self: Self) -> aiohttp.ClientSession:
60
60
61
61
return self ._client
62
62
63
- def _spawn_request (
64
- self : Self ,
63
+ async def startup (self ) -> None :
64
+ """
65
+ Startup method to initialize aiohttp.ClientSession.
66
+
67
+ :returns nothing.
68
+ """
69
+ self ._client = self ._get_client ()
70
+
71
+ async def shutdown (self ) -> None :
72
+ """Shutdown method to run all pending requests and close the session.
73
+
74
+ :returns nothing.
75
+ """
76
+ if self ._pending :
77
+ await asyncio .gather (* self ._pending , return_exceptions = True )
78
+ if self ._client is not None :
79
+ await self ._client .close ()
80
+
81
+ async def _spawn_request (
82
+ self ,
65
83
endpoint : str ,
66
84
payload : dict [str , Any ],
67
85
) -> None :
@@ -72,9 +90,9 @@ def _spawn_request(
72
90
"""
73
91
74
92
async def _send () -> None :
75
- session = self ._get_session ()
93
+ client = self ._get_client ()
76
94
77
- async with session .post (
95
+ async with client .post (
78
96
urljoin (self .url , endpoint ),
79
97
headers = {"access-token" : self .api_token },
80
98
json = payload ,
@@ -87,8 +105,8 @@ async def _send() -> None:
87
105
self ._pending .add (task )
88
106
task .add_done_callback (self ._pending .discard )
89
107
90
- def post_send (
91
- self : Self ,
108
+ async def post_send (
109
+ self ,
92
110
message : TaskiqMessage ,
93
111
) -> Union [None , Coroutine [Any , Any , None ], "CoroutineType[Any, Any, None]" ]:
94
112
"""
@@ -99,7 +117,7 @@ def post_send(
99
117
100
118
:param message: kicked message.
101
119
"""
102
- self ._spawn_request (
120
+ await self ._spawn_request (
103
121
f"/api/tasks/{ message .task_id } /queued" ,
104
122
{
105
123
"args" : message .args ,
@@ -111,7 +129,7 @@ def post_send(
111
129
)
112
130
return super ().post_send (message )
113
131
114
- def pre_execute (
132
+ async def pre_execute (
115
133
self ,
116
134
message : TaskiqMessage ,
117
135
) -> Union [
@@ -128,7 +146,7 @@ def pre_execute(
128
146
:param message: incoming parsed taskiq message.
129
147
:return: modified message.
130
148
"""
131
- self ._spawn_request (
149
+ await self ._spawn_request (
132
150
f"/api/tasks/{ message .task_id } /started" ,
133
151
{
134
152
"args" : message .args ,
@@ -140,7 +158,7 @@ def pre_execute(
140
158
)
141
159
return super ().pre_execute (message )
142
160
143
- def post_execute (
161
+ async def post_execute (
144
162
self ,
145
163
message : TaskiqMessage ,
146
164
result : TaskiqResult [Any ],
@@ -154,7 +172,7 @@ def post_execute(
154
172
:param message: incoming message.
155
173
:param result: result of execution for current task.
156
174
"""
157
- self ._spawn_request (
175
+ await self ._spawn_request (
158
176
f"/api/tasks/{ message .task_id } /executed" ,
159
177
{
160
178
"finishedAt" : self ._now_iso (),
0 commit comments