44from concurrent .futures import ThreadPoolExecutor
55from logging import getLogger
66from time import time
7- from typing import Any , Callable , Dict
7+ from typing import Any , Callable , Dict , Optional
88
99from taskiq .abc .broker import AsyncBroker
1010from taskiq .abc .middleware import TaskiqMiddleware
1111from taskiq .cli .args import TaskiqArgs
1212from taskiq .cli .log_collector import log_collector
1313from taskiq .cli .params_parser import parse_params
14- from taskiq .context import Context , context_updater
14+ from taskiq .context import Context
1515from taskiq .message import BrokerMessage , TaskiqMessage
1616from taskiq .result import TaskiqResult
1717from taskiq .utils import maybe_awaitable
1818
1919logger = getLogger (__name__ )
2020
2121
22+ def inject_context (
23+ signature : Optional [inspect .Signature ],
24+ message : TaskiqMessage ,
25+ broker : AsyncBroker ,
26+ ) -> None :
27+ """
28+ Inject context parameter in message's kwargs.
29+
30+ This function parses signature to get
31+ the context parameter definition.
32+
33+ If at least one parameter has the Context
34+ type, it will add current context as kwarg.
35+
36+ :param signature: function's signature.
37+ :param message: current taskiq message.
38+ :param broker: current broker.
39+ """
40+ if signature is None :
41+ return
42+ for param_name , param in signature .parameters .items ():
43+ if param .annotation is param .empty :
44+ continue
45+ if param .annotation is Context :
46+ message .kwargs [param_name ] = Context (message .copy (), broker )
47+
48+
2249def _run_sync (target : Callable [..., Any ], message : TaskiqMessage ) -> Any :
2350 """
2451 Runs function synchronously.
@@ -40,11 +67,8 @@ def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None:
4067 self .broker = broker
4168 self .cli_args = cli_args
4269 self .task_signatures : Dict [str , inspect .Signature ] = {}
43- if not cli_args .no_parse :
44- for task in self .broker .available_tasks .values ():
45- self .task_signatures [task .task_name ] = inspect .signature (
46- task .original_func ,
47- )
70+ for task in self .broker .available_tasks .values ():
71+ self .task_signatures [task .task_name ] = inspect .signature (task .original_func )
4872 self .executor = ThreadPoolExecutor (
4973 max_workers = cli_args .max_threadpool_threads ,
5074 )
@@ -100,11 +124,10 @@ async def callback( # noqa: C901
100124 taskiq_msg .task_name ,
101125 taskiq_msg .task_id ,
102126 )
103- with context_updater (Context (taskiq_msg , self .broker )):
104- result = await self .run_task (
105- target = self .broker .available_tasks [message .task_name ].original_func ,
106- message = taskiq_msg ,
107- )
127+ result = await self .run_task (
128+ target = self .broker .available_tasks [message .task_name ].original_func ,
129+ message = taskiq_msg ,
130+ )
108131 for middleware in self .broker .middlewares :
109132 if middleware .__class__ .post_execute != TaskiqMiddleware .post_execute :
110133 await maybe_awaitable (middleware .post_execute (taskiq_msg , result ))
@@ -147,7 +170,15 @@ async def run_task( # noqa: C901, WPS210
147170 logs = io .StringIO ()
148171 returned = None
149172 found_exception = None
150- parse_params (self .task_signatures .get (message .task_name ), message )
173+ signature = self .task_signatures .get (message .task_name )
174+ if self .cli_args .no_parse :
175+ signature = None
176+ parse_params (signature , message )
177+ inject_context (
178+ self .task_signatures .get (message .task_name ),
179+ message ,
180+ self .broker ,
181+ )
151182 # Captures function's logs.
152183 with log_collector (logs , self .cli_args .log_collector_format ):
153184 # Start a timer.
0 commit comments