Skip to content

Commit 7b6609b

Browse files
committed
Added on_error hook in middleware.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent c72a62b commit 7b6609b

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

taskiq/abc/middleware.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,17 @@ def post_execute(
8383
:param result: result of execution for current task.
8484
:param labels: task's labels. Without user-supplied labels.
8585
"""
86+
87+
def on_error(
88+
self,
89+
message: "TaskiqMessage",
90+
result: "TaskiqResult[Any]",
91+
exception: Exception,
92+
) -> "Union[None, Coroutine[Any, Any, None]]":
93+
"""
94+
This function is called when exception is found.
95+
96+
:param message: incoming message.
97+
:param result: returned value.
98+
:param exception: found exception.
99+
"""

taskiq/brokers/inmemory_broker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def kick(self, message: BrokerMessage) -> None:
125125
message=taskiq_message,
126126
log_collector_format=self.logs_format,
127127
executor=self.executor,
128+
middlewares=self.middlewares,
128129
)
129130
await self.result_backend.set_result(message.task_id, result)
130131

taskiq/cli/async_task_runner.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from concurrent.futures import Executor, ThreadPoolExecutor
77
from logging import getLogger
88
from time import time
9-
from typing import Any, Callable, Dict, NoReturn, Optional
9+
from typing import Any, Callable, Dict, List, NoReturn, Optional
1010

1111
from pydantic import parse_obj_as
1212

1313
from taskiq.abc.broker import AsyncBroker
14+
from taskiq.abc.middleware import TaskiqMiddleware
1415
from taskiq.cli.args import TaskiqArgs
1516
from taskiq.cli.log_collector import log_collector
1617
from taskiq.message import TaskiqMessage
@@ -97,12 +98,13 @@ def run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
9798
return target(*message.args, **message.kwargs)
9899

99100

100-
async def run_task( # noqa: WPS210
101+
async def run_task( # noqa: C901, WPS210, WPS211
101102
target: Callable[..., Any],
102103
signature: Optional[inspect.Signature],
103104
message: TaskiqMessage,
104105
log_collector_format: str,
105106
executor: Optional[Executor] = None,
107+
middlewares: Optional[List[TaskiqMiddleware]] = None,
106108
) -> TaskiqResult[Any]:
107109
"""
108110
This function actually executes functions.
@@ -123,12 +125,16 @@ async def run_task( # noqa: WPS210
123125
:param message: received message.
124126
:param log_collector_format: Log format in wich logs are collected.
125127
:param executor: executor to run sync tasks.
128+
:param middlewares: list of broker's middlewares in case of errors.
126129
:return: result of execution.
127130
"""
131+
if middlewares is None:
132+
middlewares = []
133+
128134
loop = asyncio.get_running_loop()
129135
logs = io.StringIO()
130-
is_err = False
131136
returned = None
137+
found_exception = None
132138
# Captures function's logs.
133139
parse_params(signature, message)
134140
with log_collector(logs, log_collector_format):
@@ -144,7 +150,7 @@ async def run_task( # noqa: WPS210
144150
message,
145151
)
146152
except Exception as exc:
147-
is_err = True
153+
found_exception = exc
148154
logger.error(
149155
"Exception found while executing function: %s",
150156
exc,
@@ -154,12 +160,23 @@ async def run_task( # noqa: WPS210
154160

155161
raw_logs = logs.getvalue()
156162
logs.close()
157-
return TaskiqResult(
158-
is_err=is_err,
163+
result: "TaskiqResult[Any]" = TaskiqResult(
164+
is_err=found_exception is not None,
159165
log=raw_logs,
160166
return_value=returned,
161167
execution_time=execution_time,
162168
)
169+
if found_exception is not None:
170+
for middleware in middlewares:
171+
err_handler = middleware.on_error(
172+
message,
173+
result,
174+
found_exception,
175+
)
176+
if inspect.isawaitable(err_handler):
177+
await err_handler
178+
179+
return result
163180

164181

165182
def exit_process(task: "asyncio.Task[Any]") -> NoReturn:
@@ -301,11 +318,12 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
301318
else:
302319
taskiq_msg = pre_ex_res # type: ignore
303320
result = await run_task(
304-
broker.available_tasks[message.task_name].original_func,
305-
task_signatures.get(message.task_name),
306-
taskiq_msg,
307-
cli_args.log_collector_format,
308-
executor,
321+
target=broker.available_tasks[message.task_name].original_func,
322+
signature=task_signatures.get(message.task_name),
323+
message=taskiq_msg,
324+
log_collector_format=cli_args.log_collector_format,
325+
executor=executor,
326+
middlewares=broker.middlewares,
309327
)
310328
for middleware in broker.middlewares:
311329
post_ex_res = middleware.post_execute(

0 commit comments

Comments
 (0)