Skip to content

Commit e335493

Browse files
committed
add: SmartRetryMiddleware
rename: middleware file
1 parent 23422c7 commit e335493

File tree

5 files changed

+196
-3
lines changed

5 files changed

+196
-3
lines changed

taskiq/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@
2525
)
2626
from taskiq.funcs import gather
2727
from taskiq.message import BrokerMessage, TaskiqMessage
28-
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
29-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
28+
from taskiq.middlewares import (
29+
PrometheusMiddleware,
30+
SimpleRetryMiddleware,
31+
SmartRetryMiddleware,
32+
)
3033
from taskiq.result import TaskiqResult
3134
from taskiq.scheduler.scheduled_task import ScheduledTask
3235
from taskiq.scheduler.scheduler import TaskiqScheduler
3336
from taskiq.state import TaskiqState
3437
from taskiq.task import AsyncTaskiqTask
3538

3639
__version__ = version("taskiq")
40+
3741
__all__ = [
3842
"AckableMessage",
3943
"AsyncBroker",
@@ -52,6 +56,7 @@
5256
"SecurityError",
5357
"SendTaskError",
5458
"SimpleRetryMiddleware",
59+
"SmartRetryMiddleware",
5560
"TaskiqDepends",
5661
"TaskiqError",
5762
"TaskiqEvents",

taskiq/middlewares/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,12 @@
11
"""Taskiq middlewares."""
2+
3+
4+
from .prometheus_middleware import PrometheusMiddleware
5+
from .simple_retry_middleware import SimpleRetryMiddleware
6+
from .smart_retry_middleware import SmartRetryMiddleware
7+
8+
__all__ = (
9+
"PrometheusMiddleware",
10+
"SimpleRetryMiddleware",
11+
"SmartRetryMiddleware",
12+
)
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
from __future__ import annotations
2+
3+
import datetime
4+
import random
5+
from logging import getLogger
6+
from typing import Any
7+
8+
from taskiq import ScheduleSource
9+
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.exceptions import NoResultError
11+
from taskiq.kicker import AsyncKicker
12+
from taskiq.message import TaskiqMessage
13+
from taskiq.result import TaskiqResult
14+
15+
__all__ = ("SmartRetryMiddleware",)
16+
17+
_logger = getLogger("taskiq.smart_retry_middleware")
18+
19+
20+
class SmartRetryMiddleware(TaskiqMiddleware):
21+
"""Middleware to retry tasks delays.
22+
23+
This middleware retries failed tasks with support for:
24+
- max retries
25+
- delay
26+
- jitter
27+
- exponential backoff
28+
"""
29+
30+
def __init__(
31+
self,
32+
default_retry_count: int = 3,
33+
default_retry_label: bool = False,
34+
no_result_on_retry: bool = True,
35+
default_delay: float = 5,
36+
use_jitter: bool = False,
37+
use_delay_exponent: bool = False,
38+
max_delay_exponent: float = 60,
39+
schedule_source: ScheduleSource | None = None,
40+
) -> None:
41+
"""
42+
Initialize retry middleware.
43+
44+
:param default_retry_count: Default max retries if not specified.
45+
:param default_retry_label: Whether to retry tasks by default.
46+
:param no_result_on_retry: Replace result with NoResultError on retry.
47+
:param default_delay: Delay in seconds before retrying.
48+
:param use_jitter: Add random jitter to retry delay.
49+
:param use_delay_exponent: Apply exponential backoff to delay.
50+
:param max_delay_exponent: Maximum allowed delay when using backoff.
51+
:param schedule_source: Schedule source to use for scheduling.
52+
If None, the default broker will be used.
53+
"""
54+
super().__init__()
55+
self.default_retry_count = default_retry_count
56+
self.default_retry_label = default_retry_label
57+
self.no_result_on_retry = no_result_on_retry
58+
self.default_delay = default_delay
59+
self.use_jitter = use_jitter
60+
self.use_delay_exponent = use_delay_exponent
61+
self.max_delay_exponent = max_delay_exponent
62+
self.schedule_source = schedule_source
63+
64+
def is_retry_on_error(self, message: TaskiqMessage) -> bool:
65+
"""
66+
Check if retry is enabled for this task.
67+
68+
Looks for `retry_on_error` label, falls back to default.
69+
70+
:param message: Original task message.
71+
:return: True if should retry on error.
72+
"""
73+
retry_on_error = message.labels.get("retry_on_error")
74+
if isinstance(retry_on_error, str):
75+
retry_on_error = retry_on_error.lower() == "true"
76+
if retry_on_error is None:
77+
retry_on_error = self.default_retry_label
78+
return retry_on_error
79+
80+
def make_delay(self, message: TaskiqMessage, retries: int) -> float:
81+
"""
82+
Calculate retry delay.
83+
84+
Includes jitter and exponential backoff if enabled.
85+
86+
:param message: Task message.
87+
:param retries: Current retry count.
88+
:return: Delay in seconds.
89+
"""
90+
delay = float(message.labels.get("delay", self.default_delay))
91+
if self.use_delay_exponent:
92+
delay = min(delay * retries, self.max_delay_exponent)
93+
94+
if self.use_jitter:
95+
delay += random.random() # noqa: S311
96+
97+
return delay
98+
99+
async def on_send(
100+
self,
101+
kicker: AsyncKicker[Any, Any],
102+
message: TaskiqMessage,
103+
delay: float,
104+
) -> None:
105+
"""Execute the task with a delay."""
106+
if isinstance(self.schedule_source, ScheduleSource):
107+
target_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
108+
seconds=delay,
109+
)
110+
await kicker.schedule_by_time(
111+
self.schedule_source,
112+
target_time,
113+
*message.args,
114+
**message.kwargs,
115+
)
116+
else:
117+
await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs)
118+
119+
async def on_error(
120+
self,
121+
message: TaskiqMessage,
122+
result: TaskiqResult[Any],
123+
exception: BaseException,
124+
) -> None:
125+
"""
126+
Retry on error.
127+
128+
If an error is raised during task execution,
129+
this middleware schedules the task to be retried
130+
after a calculated delay.
131+
132+
:param message: Message that caused the error.
133+
:param result: Execution result.
134+
:param exception: Caught exception.
135+
"""
136+
if isinstance(exception, NoResultError):
137+
return
138+
139+
retry_on_error = self.is_retry_on_error(message)
140+
141+
if not retry_on_error:
142+
return
143+
144+
retries = int(message.labels.get("_retries", 0)) + 1
145+
max_retries = int(message.labels.get("max_retries", self.default_retry_count))
146+
147+
if retries < max_retries:
148+
delay = self.make_delay(message, retries)
149+
150+
_logger.info(
151+
"Task %s failed. Retrying %d/%d in %.2f seconds.",
152+
message.task_name,
153+
retries,
154+
max_retries,
155+
delay,
156+
)
157+
158+
kicker: AsyncKicker[Any, Any] = (
159+
AsyncKicker(
160+
task_name=message.task_name,
161+
broker=self.broker,
162+
labels=message.labels,
163+
)
164+
.with_task_id(message.task_id)
165+
.with_labels(_retries=retries)
166+
)
167+
168+
await self.on_send(kicker, message, delay)
169+
170+
if self.no_result_on_retry:
171+
result.error = NoResultError()
172+
173+
else:
174+
_logger.warning(
175+
"Task '%s' invocation failed. Maximum retries count is reached.",
176+
message.task_name,
177+
)

tests/middlewares/test_simple_retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from taskiq.formatters.json_formatter import JSONFormatter
77
from taskiq.message import TaskiqMessage
8-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
8+
from taskiq.middlewares.simple_retry_middleware import SimpleRetryMiddleware
99
from taskiq.result import TaskiqResult
1010

1111

0 commit comments

Comments
 (0)