6
6
from concurrent .futures import Executor , ThreadPoolExecutor
7
7
from logging import getLogger
8
8
from time import time
9
- from typing import Any , Callable , Dict , NoReturn , Optional
9
+ from typing import Any , Callable , Dict , List , NoReturn , Optional
10
10
11
11
from pydantic import parse_obj_as
12
12
13
13
from taskiq .abc .broker import AsyncBroker
14
+ from taskiq .abc .middleware import TaskiqMiddleware
14
15
from taskiq .cli .args import TaskiqArgs
15
16
from taskiq .cli .log_collector import log_collector
16
17
from taskiq .message import TaskiqMessage
@@ -97,12 +98,13 @@ def run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
97
98
return target (* message .args , ** message .kwargs )
98
99
99
100
100
- async def run_task ( # noqa: WPS210
101
+ async def run_task ( # noqa: C901, WPS210, WPS211
101
102
target : Callable [..., Any ],
102
103
signature : Optional [inspect .Signature ],
103
104
message : TaskiqMessage ,
104
105
log_collector_format : str ,
105
106
executor : Optional [Executor ] = None ,
107
+ middlewares : Optional [List [TaskiqMiddleware ]] = None ,
106
108
) -> TaskiqResult [Any ]:
107
109
"""
108
110
This function actually executes functions.
@@ -123,12 +125,16 @@ async def run_task( # noqa: WPS210
123
125
:param message: received message.
124
126
:param log_collector_format: Log format in wich logs are collected.
125
127
:param executor: executor to run sync tasks.
128
+ :param middlewares: list of broker's middlewares in case of errors.
126
129
:return: result of execution.
127
130
"""
131
+ if middlewares is None :
132
+ middlewares = []
133
+
128
134
loop = asyncio .get_running_loop ()
129
135
logs = io .StringIO ()
130
- is_err = False
131
136
returned = None
137
+ found_exception = None
132
138
# Captures function's logs.
133
139
parse_params (signature , message )
134
140
with log_collector (logs , log_collector_format ):
@@ -144,7 +150,7 @@ async def run_task( # noqa: WPS210
144
150
message ,
145
151
)
146
152
except Exception as exc :
147
- is_err = True
153
+ found_exception = exc
148
154
logger .error (
149
155
"Exception found while executing function: %s" ,
150
156
exc ,
@@ -154,12 +160,23 @@ async def run_task( # noqa: WPS210
154
160
155
161
raw_logs = logs .getvalue ()
156
162
logs .close ()
157
- return TaskiqResult (
158
- is_err = is_err ,
163
+ result : "TaskiqResult[Any]" = TaskiqResult (
164
+ is_err = found_exception is not None ,
159
165
log = raw_logs ,
160
166
return_value = returned ,
161
167
execution_time = execution_time ,
162
168
)
169
+ if found_exception is not None :
170
+ for middleware in middlewares :
171
+ err_handler = middleware .on_error (
172
+ message ,
173
+ result ,
174
+ found_exception ,
175
+ )
176
+ if inspect .isawaitable (err_handler ):
177
+ await err_handler
178
+
179
+ return result
163
180
164
181
165
182
def exit_process (task : "asyncio.Task[Any]" ) -> NoReturn :
@@ -301,11 +318,12 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
301
318
else :
302
319
taskiq_msg = pre_ex_res # type: ignore
303
320
result = await run_task (
304
- broker .available_tasks [message .task_name ].original_func ,
305
- task_signatures .get (message .task_name ),
306
- taskiq_msg ,
307
- cli_args .log_collector_format ,
308
- executor ,
321
+ target = broker .available_tasks [message .task_name ].original_func ,
322
+ signature = task_signatures .get (message .task_name ),
323
+ message = taskiq_msg ,
324
+ log_collector_format = cli_args .log_collector_format ,
325
+ executor = executor ,
326
+ middlewares = broker .middlewares ,
309
327
)
310
328
for middleware in broker .middlewares :
311
329
post_ex_res = middleware .post_execute (
0 commit comments