1- import inspect
21import os
32import sys
43from abc import ABC , abstractmethod
4+ from collections import defaultdict
55from functools import wraps
66from logging import getLogger
77from typing import ( # noqa: WPS235
88 TYPE_CHECKING ,
99 Any ,
10+ Awaitable ,
1011 Callable ,
1112 Coroutine ,
13+ DefaultDict ,
1214 Dict ,
1315 List ,
1416 Optional ,
1820)
1921from uuid import uuid4
2022
21- from typing_extensions import ParamSpec
23+ from typing_extensions import ParamSpec , TypeAlias
2224
25+ from taskiq .abc .middleware import TaskiqMiddleware
2326from taskiq .decor import AsyncTaskiqDecoratedTask
27+ from taskiq .events import TaskiqEvents
2428from taskiq .formatters .json_formatter import JSONFormatter
2529from taskiq .message import BrokerMessage
2630from taskiq .result_backends .dummy import DummyResultBackend
31+ from taskiq .state import TaskiqState
32+ from taskiq .utils import maybe_awaitable
2733
28- if TYPE_CHECKING :
34+ if TYPE_CHECKING : # pragma: no cover
2935 from taskiq .abc .formatter import TaskiqFormatter
30- from taskiq .abc .middleware import TaskiqMiddleware
3136 from taskiq .abc .result_backend import AsyncResultBackend
3237
3338_T = TypeVar ("_T" ) # noqa: WPS111
3439_FuncParams = ParamSpec ("_FuncParams" )
3540_ReturnType = TypeVar ("_ReturnType" )
3641
42+ EventHandler : TypeAlias = Callable [[TaskiqState ], Optional [Awaitable [None ]]]
43+
3744logger = getLogger ("taskiq" )
3845
3946
@@ -49,7 +56,7 @@ def default_id_generator() -> str:
4956 return uuid4 ().hex
5057
5158
52- class AsyncBroker (ABC ):
59+ class AsyncBroker (ABC ): # noqa: WPS230
5360 """
5461 Async broker.
5562
@@ -75,8 +82,16 @@ def __init__(
7582 self .decorator_class = AsyncTaskiqDecoratedTask
7683 self .formatter : "TaskiqFormatter" = JSONFormatter ()
7784 self .id_generator = task_id_generator
78-
79- def add_middlewares (self , middlewares : "List[TaskiqMiddleware]" ) -> None :
85+ # Every event has a list of handlers.
86+ # Every handler is a function which takes state as a first argument.
87+ # And handler can be either sync or async.
88+ self .event_handlers : DefaultDict [ # noqa: WPS234
89+ TaskiqEvents ,
90+ List [Callable [[TaskiqState ], Optional [Awaitable [None ]]]],
91+ ] = defaultdict (list )
92+ self .state = TaskiqState ()
93+
94+ def add_middlewares (self , * middlewares : "TaskiqMiddleware" ) -> None :
8095 """
8196 Add a list of middlewares.
8297
@@ -86,11 +101,23 @@ def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None:
86101 :param middlewares: list of middlewares.
87102 """
88103 for middleware in middlewares :
104+ if not isinstance (middleware , TaskiqMiddleware ):
105+ logger .warning (
106+ f"Middleware { middleware } is not an instance of TaskiqMiddleware. "
107+ "Skipping..." ,
108+ )
109+ continue
89110 middleware .set_broker (self )
90111 self .middlewares .append (middleware )
91112
92113 async def startup (self ) -> None :
93114 """Do something when starting broker."""
115+ event = TaskiqEvents .CLIENT_STARTUP
116+ if self .is_worker_process :
117+ event = TaskiqEvents .WORKER_STARTUP
118+
119+ for handler in self .event_handlers [event ]:
120+ await maybe_awaitable (handler (self .state ))
94121
95122 async def shutdown (self ) -> None :
96123 """
@@ -99,11 +126,13 @@ async def shutdown(self) -> None:
99126 This method is called,
100127 when broker is closig.
101128 """
102- for middleware in self .middlewares :
103- middleware_shutdown = middleware .shutdown ()
104- if inspect .isawaitable (middleware_shutdown ):
105- await middleware_shutdown
106- await self .result_backend .shutdown ()
129+ event = TaskiqEvents .CLIENT_SHUTDOWN
130+ if self .is_worker_process :
131+ event = TaskiqEvents .WORKER_SHUTDOWN
132+
133+ # Call all shutdown events.
134+ for handler in self .event_handlers [event ]:
135+ await maybe_awaitable (handler (self .state ))
107136
108137 @abstractmethod
109138 async def kick (
@@ -232,3 +261,43 @@ def inner(
232261 inner_task_name = task_name ,
233262 inner_labels = labels or {},
234263 )
264+
265+ def on_event (self , * events : TaskiqEvents ) -> Callable [[EventHandler ], EventHandler ]:
266+ """
267+ Adds event handler.
268+
269+ This function adds function to call when event occurs.
270+
271+ :param events: events to react to.
272+ :return: a decorator function.
273+ """
274+
275+ def handler (function : EventHandler ) -> EventHandler :
276+ for event in events :
277+ self .event_handlers [event ].append (function )
278+ return function
279+
280+ return handler
281+
282+ def add_event_handler (
283+ self ,
284+ event : TaskiqEvents ,
285+ handler : EventHandler ,
286+ ) -> None :
287+ """
288+ Adds event handler.
289+
290+ this function is the same as on_event.
291+
292+ >>> broker.add_event_handler(TaskiqEvents.WORKER_STARTUP, my_startup)
293+
294+ if similar to:
295+
296+ >>> @broker.on_event(TaskiqEvents.WORKER_STARTUP)
297+ >>> async def my_startup(context: Context) -> None:
298+ >>> ...
299+
300+ :param event: Event to react to.
301+ :param handler: handler to call when event is started.
302+ """
303+ self .event_handlers [event ].append (handler )
0 commit comments