22from datetime import UTC , datetime
33from logging import getLogger
44from types import CoroutineType
5- from typing import Any , Coroutine , Self , Union
5+ from typing import Any , Coroutine , Union
66from urllib .parse import urljoin
77
88import aiohttp
@@ -51,7 +51,7 @@ def __init__(
5151 def _now_iso () -> str :
5252 return datetime .now (UTC ).replace (tzinfo = None ).isoformat ()
5353
54- def _get_session (self : Self ) -> aiohttp .ClientSession :
54+ def _get_client (self ) -> aiohttp .ClientSession :
5555 """Create and cache session."""
5656 if self ._client is None or self ._client .closed :
5757 self ._client = aiohttp .ClientSession (
@@ -60,8 +60,26 @@ def _get_session(self: Self) -> aiohttp.ClientSession:
6060
6161 return self ._client
6262
63- def _spawn_request (
64- self : Self ,
63+ async def startup (self ) -> None :
64+ """
65+ Startup method to initialize aiohttp.ClientSession.
66+
67+ :returns nothing.
68+ """
69+ self ._client = self ._get_client ()
70+
71+ async def shutdown (self ) -> None :
72+ """Shutdown method to run all pending requests and close the session.
73+
74+ :returns nothing.
75+ """
76+ if self ._pending :
77+ await asyncio .gather (* self ._pending , return_exceptions = True )
78+ if self ._client is not None :
79+ await self ._client .close ()
80+
81+ async def _spawn_request (
82+ self ,
6583 endpoint : str ,
6684 payload : dict [str , Any ],
6785 ) -> None :
@@ -72,9 +90,9 @@ def _spawn_request(
7290 """
7391
7492 async def _send () -> None :
75- session = self ._get_session ()
93+ client = self ._get_client ()
7694
77- async with session .post (
95+ async with client .post (
7896 urljoin (self .url , endpoint ),
7997 headers = {"access-token" : self .api_token },
8098 json = payload ,
@@ -87,8 +105,8 @@ async def _send() -> None:
87105 self ._pending .add (task )
88106 task .add_done_callback (self ._pending .discard )
89107
90- def post_send (
91- self : Self ,
108+ async def post_send (
109+ self ,
92110 message : TaskiqMessage ,
93111 ) -> Union [None , Coroutine [Any , Any , None ], "CoroutineType[Any, Any, None]" ]:
94112 """
@@ -99,7 +117,7 @@ def post_send(
99117
100118 :param message: kicked message.
101119 """
102- self ._spawn_request (
120+ await self ._spawn_request (
103121 f"/api/tasks/{ message .task_id } /queued" ,
104122 {
105123 "args" : message .args ,
@@ -111,7 +129,7 @@ def post_send(
111129 )
112130 return super ().post_send (message )
113131
114- def pre_execute (
132+ async def pre_execute (
115133 self ,
116134 message : TaskiqMessage ,
117135 ) -> Union [
@@ -128,7 +146,7 @@ def pre_execute(
128146 :param message: incoming parsed taskiq message.
129147 :return: modified message.
130148 """
131- self ._spawn_request (
149+ await self ._spawn_request (
132150 f"/api/tasks/{ message .task_id } /started" ,
133151 {
134152 "args" : message .args ,
@@ -140,7 +158,7 @@ def pre_execute(
140158 )
141159 return super ().pre_execute (message )
142160
143- def post_execute (
161+ async def post_execute (
144162 self ,
145163 message : TaskiqMessage ,
146164 result : TaskiqResult [Any ],
@@ -154,7 +172,7 @@ def post_execute(
154172 :param message: incoming message.
155173 :param result: result of execution for current task.
156174 """
157- self ._spawn_request (
175+ await self ._spawn_request (
158176 f"/api/tasks/{ message .task_id } /executed" ,
159177 {
160178 "finishedAt" : self ._now_iso (),
0 commit comments