Skip to content

Commit 4344e6a

Browse files
committed
add: SmartRetryMiddleware
rename: middleware file
1 parent 23422c7 commit 4344e6a

File tree

5 files changed

+189
-3
lines changed

5 files changed

+189
-3
lines changed

taskiq/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@
2525
)
2626
from taskiq.funcs import gather
2727
from taskiq.message import BrokerMessage, TaskiqMessage
28-
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
29-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
28+
from taskiq.middlewares import (
29+
PrometheusMiddleware,
30+
SimpleRetryMiddleware,
31+
SmartRetryMiddleware,
32+
)
3033
from taskiq.result import TaskiqResult
3134
from taskiq.scheduler.scheduled_task import ScheduledTask
3235
from taskiq.scheduler.scheduler import TaskiqScheduler
3336
from taskiq.state import TaskiqState
3437
from taskiq.task import AsyncTaskiqTask
3538

3639
__version__ = version("taskiq")
40+
3741
__all__ = [
3842
"AckableMessage",
3943
"AsyncBroker",
@@ -52,6 +56,7 @@
5256
"SecurityError",
5357
"SendTaskError",
5458
"SimpleRetryMiddleware",
59+
"SmartRetryMiddleware",
5560
"TaskiqDepends",
5661
"TaskiqError",
5762
"TaskiqEvents",

taskiq/middlewares/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,12 @@
11
"""Taskiq middlewares."""
2+
3+
4+
from .prometheus_middleware import PrometheusMiddleware
5+
from .simple_retry_middleware import SimpleRetryMiddleware
6+
from .smart_retry_middleware import SmartRetryMiddleware
7+
8+
__all__ = (
9+
"PrometheusMiddleware",
10+
"SimpleRetryMiddleware",
11+
"SmartRetryMiddleware",
12+
)
File renamed without changes.
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
__all__ = ("SmartRetryMiddleware",)
2+
3+
import datetime
4+
import random
5+
from logging import getLogger
6+
from typing import Any
7+
8+
from taskiq import ScheduleSource
9+
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.exceptions import NoResultError
11+
from taskiq.kicker import AsyncKicker
12+
from taskiq.message import TaskiqMessage
13+
from taskiq.result import TaskiqResult
14+
15+
_logger = getLogger("taskiq.smart_retry_middleware")
16+
17+
18+
class SmartRetryMiddleware(TaskiqMiddleware):
19+
"""Middleware to retry tasks delays.
20+
21+
This middleware retries failed tasks with support for:
22+
- max retries
23+
- delay
24+
- jitter
25+
- exponential backoff
26+
"""
27+
28+
def __init__(
29+
self,
30+
default_retry_count: int = 3,
31+
default_retry_label: bool = False,
32+
no_result_on_retry: bool = True,
33+
default_delay: float = 5,
34+
use_jitter: bool = False,
35+
use_delay_exponent: bool = False,
36+
max_delay_exponent: float = 60,
37+
schedule_source: ScheduleSource | None = None,
38+
) -> None:
39+
"""
40+
Initialize retry middleware.
41+
42+
:param default_retry_count: Default max retries if not specified.
43+
:param default_retry_label: Whether to retry tasks by default.
44+
:param no_result_on_retry: Replace result with NoResultError on retry.
45+
:param default_delay: Delay in seconds before retrying.
46+
:param use_jitter: Add random jitter to retry delay.
47+
:param use_delay_exponent: Apply exponential backoff to delay.
48+
:param max_delay_exponent: Maximum allowed delay when using backoff.
49+
:param schedule_source: Schedule source to use for scheduling.
50+
If None, the default broker will be used.
51+
"""
52+
super().__init__()
53+
self.default_retry_count = default_retry_count
54+
self.default_retry_label = default_retry_label
55+
self.no_result_on_retry = no_result_on_retry
56+
self.default_delay = default_delay
57+
self.use_jitter = use_jitter
58+
self.use_delay_exponent = use_delay_exponent
59+
self.max_delay_exponent = max_delay_exponent
60+
self.schedule_source = schedule_source
61+
62+
def is_retry_on_error(self, message: TaskiqMessage) -> bool:
63+
"""
64+
Check if retry is enabled for this task.
65+
66+
Looks for `retry_on_error` label, falls back to default.
67+
68+
:param message: Original task message.
69+
:return: True if should retry on error.
70+
"""
71+
retry_on_error = message.labels.get("retry_on_error")
72+
if isinstance(retry_on_error, str):
73+
retry_on_error = retry_on_error.lower() == "true"
74+
if retry_on_error is None:
75+
retry_on_error = self.default_retry_label
76+
return retry_on_error
77+
78+
def make_delay(self, message: TaskiqMessage, retries: int) -> float:
79+
"""
80+
Calculate retry delay.
81+
82+
Includes jitter and exponential backoff if enabled.
83+
84+
:param message: Task message.
85+
:param retries: Current retry count.
86+
:return: Delay in seconds.
87+
"""
88+
delay = float(message.labels.get("delay", self.default_delay))
89+
if self.use_delay_exponent:
90+
delay = min(delay * retries, self.max_delay_exponent)
91+
92+
if self.use_jitter:
93+
delay += random.random() # noqa: S311
94+
95+
return delay
96+
97+
async def on_send(
98+
self,
99+
kicker: AsyncKicker[Any, Any],
100+
message: TaskiqMessage,
101+
delay: float,
102+
) -> None:
103+
"""Execute the task with a delay."""
104+
if isinstance(self.schedule_source, ScheduleSource):
105+
target_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
106+
seconds=delay,
107+
)
108+
await kicker.schedule_by_time(self.schedule_source, target_time)
109+
else:
110+
await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs)
111+
112+
async def on_error(
113+
self,
114+
message: TaskiqMessage,
115+
result: TaskiqResult[Any],
116+
exception: BaseException,
117+
) -> None:
118+
"""
119+
Retry on error.
120+
121+
If an error is raised during task execution,
122+
this middleware schedules the task to be retried
123+
after a calculated delay.
124+
125+
:param message: Message that caused the error.
126+
:param result: Execution result.
127+
:param exception: Caught exception.
128+
"""
129+
if isinstance(exception, NoResultError):
130+
return
131+
132+
retry_on_error = self.is_retry_on_error(message)
133+
134+
if not retry_on_error:
135+
return
136+
137+
retries = int(message.labels.get("_retries", 0)) + 1
138+
max_retries = int(message.labels.get("max_retries", self.default_retry_count))
139+
140+
if retries < max_retries:
141+
delay = self.make_delay(message, retries)
142+
143+
_logger.info(
144+
"Task %s failed. Retrying %d/%d in %.2f seconds.",
145+
message.task_name,
146+
retries,
147+
max_retries,
148+
delay,
149+
)
150+
151+
kicker: AsyncKicker[Any, Any] = (
152+
AsyncKicker(
153+
task_name=message.task_name,
154+
broker=self.broker,
155+
labels=message.labels,
156+
)
157+
.with_task_id(message.task_id)
158+
.with_labels(_retries=retries)
159+
)
160+
161+
await self.on_send(kicker, message, delay)
162+
163+
if self.no_result_on_retry:
164+
result.error = NoResultError()
165+
166+
else:
167+
_logger.warning(
168+
"Task '%s' invocation failed. Maximum retries count is reached.",
169+
message.task_name,
170+
)

tests/middlewares/test_simple_retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from taskiq.formatters.json_formatter import JSONFormatter
77
from taskiq.message import TaskiqMessage
8-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
8+
from taskiq.middlewares.simple_retry_middleware import SimpleRetryMiddleware
99
from taskiq.result import TaskiqResult
1010

1111

0 commit comments

Comments
 (0)