11import asyncio
22from datetime import UTC , datetime
33from logging import getLogger
4- from typing import Any
4+ from types import CoroutineType
5+ from typing import Any , Coroutine , Self , Union
56from urllib .parse import urljoin
67
78import aiohttp
@@ -37,7 +38,7 @@ def __init__(
3738 api_token : str ,
3839 timeout : int = 5 ,
3940 taskiq_broker_name : str | None = None ,
40- ):
41+ ) -> None :
4142 super ().__init__ ()
4243 self .url = url
4344 self .timeout = timeout
@@ -50,22 +51,28 @@ def __init__(
5051 def _now_iso () -> str :
5152 return datetime .now (UTC ).replace (tzinfo = None ).isoformat ()
5253
53- async def startup (self ):
54- self ._client = aiohttp .ClientSession (
55- timeout = aiohttp .ClientTimeout (total = self .timeout ),
56- )
54+ def _get_session (self : Self ) -> aiohttp .ClientSession :
55+ """Create and cache session."""
56+ if self ._client is None or self ._client .closed :
57+ self ._client = aiohttp .ClientSession (
58+ timeout = aiohttp .ClientTimeout (total = self .timeout ),
59+ )
60+
61+ return self ._client
5762
58- async def shutdown (self ):
59- if self ._pending :
60- await asyncio .gather (* self ._pending , return_exceptions = True )
61- if self ._client is not None :
62- await self ._client .close ()
63+ def _spawn_request (
64+ self : Self ,
65+ endpoint : str ,
66+ payload : dict [str , Any ],
67+ ) -> None :
68+ """Fire and forget helper.
69+
70+ start an async POST to the admin API, keep the resulting Task in _pending
71+ so it can be awaited/cleaned during graceful shutdown.
72+ """
6373
64- def _spawn_request (self , endpoint : str , payload : dict [str , Any ]) -> None :
6574 async def _send () -> None :
66- session = self ._client or aiohttp .ClientSession (
67- timeout = aiohttp .ClientTimeout (total = self .timeout )
68- )
75+ session = self ._get_session ()
6976
7077 async with session .post (
7178 urljoin (self .url , endpoint ),
@@ -80,7 +87,18 @@ async def _send() -> None:
8087 self ._pending .add (task )
8188 task .add_done_callback (self ._pending .discard )
8289
83- async def post_send (self , message ):
90+ def post_send (
91+ self : Self ,
92+ message : TaskiqMessage ,
93+ ) -> Union [None , Coroutine [Any , Any , None ], "CoroutineType[Any, Any, None]" ]:
94+ """
95+ This hook is executed right after the task is sent.
96+
97+ This is a client-side hook. It executes right
98+ after the messages is kicked in broker.
99+
100+ :param message: kicked message.
101+ """
84102 self ._spawn_request (
85103 f"/api/tasks/{ message .task_id } /queued" ,
86104 {
@@ -93,7 +111,23 @@ async def post_send(self, message):
93111 )
94112 return super ().post_send (message )
95113
96- async def pre_execute (self , message : TaskiqMessage ):
114+ def pre_execute (
115+ self ,
116+ message : TaskiqMessage ,
117+ ) -> Union [
118+ "TaskiqMessage" ,
119+ "Coroutine[Any, Any, TaskiqMessage]" ,
120+ "CoroutineType[Any, Any, TaskiqMessage]" ,
121+ ]:
122+ """
123+ This hook is called before executing task.
124+
125+ This is a worker-side hook, which means it
126+ executes in the worker process.
127+
128+ :param message: incoming parsed taskiq message.
129+ :return: modified message.
130+ """
97131 self ._spawn_request (
98132 f"/api/tasks/{ message .task_id } /started" ,
99133 {
@@ -106,7 +140,20 @@ async def pre_execute(self, message: TaskiqMessage):
106140 )
107141 return super ().pre_execute (message )
108142
109- async def post_execute (self , message : TaskiqMessage , result : TaskiqResult [Any ]):
143+ def post_execute (
144+ self ,
145+ message : TaskiqMessage ,
146+ result : TaskiqResult [Any ],
147+ ) -> Union [None , Coroutine [Any , Any , None ], "CoroutineType[Any, Any, None]" ]:
148+ """
149+ This hook executes after task is complete.
150+
151+ This is a worker-side hook. It's called
152+ in worker process.
153+
154+ :param message: incoming message.
155+ :param result: result of execution for current task.
156+ """
110157 self ._spawn_request (
111158 f"/api/tasks/{ message .task_id } /executed" ,
112159 {
0 commit comments