Skip to content

Commit c08d06e

Browse files
authored
Merge pull request #23 from taskiq-python/feature/retries
Feature/retries
2 parents c72a62b + 0a7c595 commit c08d06e

File tree

13 files changed

+158
-57
lines changed

13 files changed

+158
-57
lines changed

taskiq/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
"""Distributed task manager."""
22
from taskiq.abc.broker import AsyncBroker, AsyncTaskiqDecoratedTask
3+
from taskiq.abc.formatter import TaskiqFormatter
4+
from taskiq.abc.middleware import TaskiqMiddleware
35
from taskiq.abc.result_backend import AsyncResultBackend
4-
from taskiq.message import TaskiqMessage
6+
from taskiq.brokers.shared_broker import async_shared_broker
7+
from taskiq.exceptions import TaskiqError
8+
from taskiq.message import BrokerMessage, TaskiqMessage
9+
from taskiq.result import TaskiqResult
510
from taskiq.task import AsyncTaskiqTask
611

712
__all__ = [
8-
"TaskiqMessage",
913
"AsyncBroker",
10-
"AsyncTaskiqDecoratedTask",
11-
"AsyncResultBackend",
14+
"TaskiqError",
15+
"TaskiqResult",
16+
"TaskiqMessage",
17+
"BrokerMessage",
18+
"TaskiqFormatter",
1219
"AsyncTaskiqTask",
20+
"TaskiqMiddleware",
21+
"AsyncResultBackend",
22+
"async_shared_broker",
23+
"AsyncTaskiqDecoratedTask",
1324
]

taskiq/abc/broker.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from abc import ABC, abstractmethod
23
from functools import wraps
34
from logging import getLogger
@@ -13,12 +14,13 @@
1314
Union,
1415
overload,
1516
)
17+
from uuid import uuid4
1618

1719
from typing_extensions import ParamSpec
1820

1921
from taskiq.decor import AsyncTaskiqDecoratedTask
22+
from taskiq.formatters.json_formatter import JSONFormatter
2023
from taskiq.message import BrokerMessage
21-
from taskiq.plugins.json_formatter import JSONFormatter
2224
from taskiq.result_backends.dummy import DummyResultBackend
2325

2426
if TYPE_CHECKING:
@@ -33,6 +35,18 @@
3335
logger = getLogger("taskiq")
3436

3537

38+
def default_id_generator() -> str:
39+
"""
40+
Default task_id generator.
41+
42+
This function is used to generate id's
43+
for tasks.
44+
45+
:return: new task_id.
46+
"""
47+
return uuid4().hex
48+
49+
3650
class AsyncBroker(ABC):
3751
"""
3852
Async broker.
@@ -47,14 +61,18 @@ class AsyncBroker(ABC):
4761
def __init__(
4862
self,
4963
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
64+
task_id_generator: Optional[Callable[[], str]] = None,
5065
) -> None:
5166
if result_backend is None:
5267
result_backend = DummyResultBackend()
68+
if task_id_generator is None:
69+
task_id_generator = default_id_generator
5370
self.middlewares: "List[TaskiqMiddleware]" = []
5471
self.result_backend = result_backend
5572
self.is_worker_process = False
5673
self.decorator_class = AsyncTaskiqDecoratedTask
5774
self.formatter: "TaskiqFormatter" = JSONFormatter()
75+
self.id_generator = task_id_generator
5876

5977
def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None:
6078
"""
@@ -79,6 +97,10 @@ async def shutdown(self) -> None:
7997
This method is called,
8098
when broker is closig.
8199
"""
100+
for middleware in self.middlewares:
101+
middleware_shutdown = middleware.shutdown()
102+
if inspect.isawaitable(middleware_shutdown):
103+
await middleware_shutdown
82104

83105
@abstractmethod
84106
async def kick(

taskiq/abc/formatter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict
32

43
from taskiq.message import BrokerMessage, TaskiqMessage
54

@@ -8,12 +7,11 @@ class TaskiqFormatter(ABC):
87
"""Custom formatter for brokers."""
98

109
@abstractmethod
11-
def dumps(self, message: TaskiqMessage, labels: Dict[str, Any]) -> BrokerMessage:
10+
def dumps(self, message: TaskiqMessage) -> BrokerMessage:
1211
"""
1312
Dump message to broker message instance.
1413
1514
:param message: message to send.
16-
:param labels: task's labels.
1715
:return: message for brokers.
1816
"""
1917

taskiq/abc/middleware.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Union
1+
from typing import TYPE_CHECKING, Any, Coroutine, Union
22

33
if TYPE_CHECKING:
44
from taskiq.abc.broker import AsyncBroker
@@ -20,10 +20,12 @@ def set_broker(self, broker: "AsyncBroker") -> None:
2020
"""
2121
self.broker = broker
2222

23+
def shutdown(self) -> Union[None, Coroutine[Any, Any, None]]:
24+
"""This function is used to do some work on shutdown."""
25+
2326
def pre_send(
2427
self,
2528
message: "TaskiqMessage",
26-
labels: Dict[str, Any],
2729
) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
2830
"""
2931
Hook that executes before sending the task to worker.
@@ -32,15 +34,13 @@ def pre_send(
3234
the message is sent to broker.
3335
3436
:param message: message to send.
35-
:param labels: task's labels.
3637
:return: modified message.
3738
"""
3839
return message
3940

4041
def post_send(
4142
self,
4243
message: "TaskiqMessage",
43-
labels: Dict[str, Any],
4444
) -> "Union[None, Coroutine[Any, Any, None]]":
4545
"""
4646
This hook is executed right after the task is sent.
@@ -49,13 +49,11 @@ def post_send(
4949
after the messages is kicked in broker.
5050
5151
:param message: kicked message.
52-
:param labels: labels for a message.
5352
"""
5453

5554
def pre_execute(
5655
self,
5756
message: "TaskiqMessage",
58-
labels: Dict[str, Any],
5957
) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
6058
"""
6159
This hook is called before executing task.
@@ -64,22 +62,35 @@ def pre_execute(
6462
executes in the worker process.
6563
6664
:param message: incoming parsed taskiq message.
67-
:param labels: task's labels without user-supplied lables.
6865
:return: modified message.
6966
"""
7067
return message
7168

7269
def post_execute(
7370
self,
71+
message: "TaskiqMessage",
7472
result: "TaskiqResult[Any]",
75-
labels: Dict[str, Any],
7673
) -> "Union[None, Coroutine[Any, Any, None]]":
7774
"""
7875
This hook executes after task is complete.
7976
8077
This is a worker-side hook. It's called
8178
in worker process.
8279
80+
:param message: incoming message.
8381
:param result: result of execution for current task.
84-
:param labels: task's labels. Without user-supplied labels.
82+
"""
83+
84+
def on_error(
85+
self,
86+
message: "TaskiqMessage",
87+
result: "TaskiqResult[Any]",
88+
exception: Exception,
89+
) -> "Union[None, Coroutine[Any, Any, None]]":
90+
"""
91+
This function is called when exception is found.
92+
93+
:param message: incoming message.
94+
:param result: returned value.
95+
:param exception: found exception.
8596
"""

taskiq/brokers/inmemory_broker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def kick(self, message: BrokerMessage) -> None:
125125
message=taskiq_message,
126126
log_collector_format=self.logs_format,
127127
executor=self.executor,
128+
middlewares=self.middlewares,
128129
)
129130
await self.result_backend.set_result(message.task_id, result)
130131

taskiq/cli/async_task_runner.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from concurrent.futures import Executor, ThreadPoolExecutor
77
from logging import getLogger
88
from time import time
9-
from typing import Any, Callable, Dict, NoReturn, Optional
9+
from typing import Any, Callable, Dict, List, NoReturn, Optional
1010

1111
from pydantic import parse_obj_as
1212

1313
from taskiq.abc.broker import AsyncBroker
14+
from taskiq.abc.middleware import TaskiqMiddleware
1415
from taskiq.cli.args import TaskiqArgs
1516
from taskiq.cli.log_collector import log_collector
1617
from taskiq.message import TaskiqMessage
@@ -97,12 +98,13 @@ def run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
9798
return target(*message.args, **message.kwargs)
9899

99100

100-
async def run_task( # noqa: WPS210
101+
async def run_task( # noqa: C901, WPS210, WPS211
101102
target: Callable[..., Any],
102103
signature: Optional[inspect.Signature],
103104
message: TaskiqMessage,
104105
log_collector_format: str,
105106
executor: Optional[Executor] = None,
107+
middlewares: Optional[List[TaskiqMiddleware]] = None,
106108
) -> TaskiqResult[Any]:
107109
"""
108110
This function actually executes functions.
@@ -123,12 +125,16 @@ async def run_task( # noqa: WPS210
123125
:param message: received message.
124126
:param log_collector_format: Log format in wich logs are collected.
125127
:param executor: executor to run sync tasks.
128+
:param middlewares: list of broker's middlewares in case of errors.
126129
:return: result of execution.
127130
"""
131+
if middlewares is None:
132+
middlewares = []
133+
128134
loop = asyncio.get_running_loop()
129135
logs = io.StringIO()
130-
is_err = False
131136
returned = None
137+
found_exception = None
132138
# Captures function's logs.
133139
parse_params(signature, message)
134140
with log_collector(logs, log_collector_format):
@@ -144,7 +150,7 @@ async def run_task( # noqa: WPS210
144150
message,
145151
)
146152
except Exception as exc:
147-
is_err = True
153+
found_exception = exc
148154
logger.error(
149155
"Exception found while executing function: %s",
150156
exc,
@@ -154,12 +160,23 @@ async def run_task( # noqa: WPS210
154160

155161
raw_logs = logs.getvalue()
156162
logs.close()
157-
return TaskiqResult(
158-
is_err=is_err,
163+
result: "TaskiqResult[Any]" = TaskiqResult(
164+
is_err=found_exception is not None,
159165
log=raw_logs,
160166
return_value=returned,
161167
execution_time=execution_time,
162168
)
169+
if found_exception is not None:
170+
for middleware in middlewares:
171+
err_handler = middleware.on_error(
172+
message,
173+
result,
174+
found_exception,
175+
)
176+
if inspect.isawaitable(err_handler):
177+
await err_handler
178+
179+
return result
163180

164181

165182
def exit_process(task: "asyncio.Task[Any]") -> NoReturn:
@@ -294,24 +311,21 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
294311
for middleware in broker.middlewares:
295312
pre_ex_res = middleware.pre_execute(
296313
taskiq_msg,
297-
broker.available_tasks[message.task_name].labels,
298314
)
299315
if inspect.isawaitable(pre_ex_res):
300316
taskiq_msg = await pre_ex_res
301317
else:
302318
taskiq_msg = pre_ex_res # type: ignore
303319
result = await run_task(
304-
broker.available_tasks[message.task_name].original_func,
305-
task_signatures.get(message.task_name),
306-
taskiq_msg,
307-
cli_args.log_collector_format,
308-
executor,
320+
target=broker.available_tasks[message.task_name].original_func,
321+
signature=task_signatures.get(message.task_name),
322+
message=taskiq_msg,
323+
log_collector_format=cli_args.log_collector_format,
324+
executor=executor,
325+
middlewares=broker.middlewares,
309326
)
310327
for middleware in broker.middlewares:
311-
post_ex_res = middleware.post_execute(
312-
result,
313-
broker.available_tasks[message.task_name].labels,
314-
)
328+
post_ex_res = middleware.post_execute(taskiq_msg, result)
315329
if inspect.isawaitable(post_ex_res):
316330
await post_ex_res
317331
try:

taskiq/formatters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Taskiq formatters."""

taskiq/plugins/json_formatter.py renamed to taskiq/formatters/json_formatter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
1-
from typing import Any, Dict
2-
31
from taskiq.abc.formatter import TaskiqFormatter
42
from taskiq.message import BrokerMessage, TaskiqMessage
53

64

75
class JSONFormatter(TaskiqFormatter):
86
"""Default taskiq formatter."""
97

10-
def dumps(self, message: TaskiqMessage, labels: Dict[str, Any]) -> BrokerMessage:
8+
def dumps(self, message: TaskiqMessage) -> BrokerMessage:
119
"""
1210
Dumps taskiq message to some broker message format.
1311
1412
:param message: message to send.
15-
:param labels: message's labels.
1613
:return: Dumped message.
1714
"""
1815
return BrokerMessage(
1916
task_id=message.task_id,
2017
task_name=message.task_name,
2118
message=message.json(),
22-
headers={
19+
labels={
2320
"content_type": "application/json",
2421
},
2522
)

0 commit comments

Comments
 (0)