Skip to content

Commit a488174

Browse files
Sobes76rusAnton
andauthored
fix: retry_middleware (#126)
Co-authored-by: Anton <[email protected]>
1 parent bd1bd9d commit a488174

File tree

3 files changed

+170
-4
lines changed

3 files changed

+170
-4
lines changed

taskiq/middlewares/retry_middleware.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any
44

55
from taskiq.abc.middleware import TaskiqMiddleware
6+
from taskiq.exceptions import NoResultError
67
from taskiq.message import TaskiqMessage
78
from taskiq.result import TaskiqResult
89

@@ -15,8 +16,12 @@ class SimpleRetryMiddleware(TaskiqMiddleware):
1516
def __init__(
1617
self,
1718
default_retry_count: int = 3,
19+
default_retry_label: bool = False,
20+
no_result_on_retry: bool = True,
1821
) -> None:
1922
self.default_retry_count = default_retry_count
23+
self.default_retry_label = default_retry_label
24+
self.no_result_on_retry = no_result_on_retry
2025

2126
async def on_error(
2227
self,
@@ -37,24 +42,34 @@ async def on_error(
3742
:param result: execution result.
3843
:param exception: found exception.
3944
"""
45+
# Valid exception
46+
if isinstance(exception, NoResultError):
47+
return
48+
4049
retry_on_error = message.labels.get("retry_on_error")
50+
if retry_on_error is None:
51+
retry_on_error = "true" if self.default_retry_label else "false"
4152
# Check if retrying is enabled for the task.
42-
if retry_on_error is None or retry_on_error.lower() != "true":
53+
if retry_on_error.lower() != "true":
4354
return
4455
new_msg = deepcopy(message)
56+
4557
# Getting number of previous retries.
4658
retries = int(new_msg.labels.get("_retries", 0)) + 1
4759
new_msg.labels["_retries"] = str(retries)
4860
max_retries = int(new_msg.labels.get("max_retries", self.default_retry_count))
61+
4962
if retries < max_retries:
5063
logger.info(
5164
"Task '%s' invocation failed. Retrying.",
5265
message.task_name,
5366
)
54-
new_msg.labels["_parent"] = message.task_id
55-
new_msg.task_id = self.broker.id_generator()
5667
broker_message = self.broker.formatter.dumps(message=new_msg)
5768
await self.broker.kick(broker_message)
69+
70+
if self.no_result_on_retry:
71+
result.error = NoResultError()
72+
5873
else:
5974
logger.warning(
6075
"Task '%s' invocation failed. Maximum retries count is reached.",

tests/middlewares/test_simple_retry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ async def test_successful_retry(broker: AsyncMock) -> None:
3737
resend: TaskiqMessage = broker.kick.await_args.args[0]
3838
assert resend.task_name == "meme"
3939
assert resend.labels["_retries"] == "1"
40-
assert resend.labels["_parent"] == "test_id"
4140

4241

4342
@pytest.mark.anyio

tests/middlewares/test_task_retry.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import asyncio
2+
import time
3+
4+
import pytest
5+
6+
from taskiq import InMemoryBroker, SimpleRetryMiddleware
7+
from taskiq.exceptions import NoResultError
8+
9+
10+
@pytest.mark.anyio
11+
async def test_wait_result() -> None:
12+
"""Tests wait_result."""
13+
14+
broker = InMemoryBroker().with_middlewares(
15+
SimpleRetryMiddleware(no_result_on_retry=True),
16+
)
17+
runs = 0
18+
19+
@broker.task(retry_on_error=True)
20+
def run_task() -> str:
21+
nonlocal runs # noqa: WPS420
22+
23+
if runs == 0:
24+
runs += 1
25+
raise Exception("Retry")
26+
27+
time.sleep(0.2)
28+
return "hello world!"
29+
30+
task = await run_task.kiq()
31+
resp = await task.wait_result(0.1, timeout=1)
32+
33+
assert resp.return_value == "hello world!"
34+
35+
36+
@pytest.mark.anyio
37+
async def test_wait_result_error() -> None:
38+
"""Tests wait_result."""
39+
40+
broker = InMemoryBroker().with_middlewares(
41+
SimpleRetryMiddleware(no_result_on_retry=False),
42+
)
43+
runs = 0
44+
45+
@broker.task(retry_on_error=True)
46+
def run_task() -> str:
47+
nonlocal runs # noqa: WPS420
48+
49+
if runs == 0:
50+
runs += 1
51+
raise ValueError("Retry")
52+
53+
time.sleep(0.2)
54+
return "hello world!"
55+
56+
task = await run_task.kiq()
57+
resp = await task.wait_result(0.1, timeout=1)
58+
with pytest.raises(ValueError):
59+
resp.raise_for_error()
60+
61+
await asyncio.sleep(0.2)
62+
resp = await task.wait_result(timeout=1)
63+
assert resp.return_value == "hello world!"
64+
65+
66+
@pytest.mark.anyio
67+
async def test_wait_result_no_result() -> None:
68+
"""Tests wait_result."""
69+
70+
broker = InMemoryBroker().with_middlewares(
71+
SimpleRetryMiddleware(no_result_on_retry=False),
72+
)
73+
done = False
74+
runs = 0
75+
76+
@broker.task(retry_on_error=True)
77+
def run_task() -> str:
78+
nonlocal runs, done # noqa: WPS420
79+
80+
if runs == 0:
81+
runs += 1
82+
raise ValueError("Retry")
83+
84+
time.sleep(0.2)
85+
done = True
86+
raise NoResultError()
87+
88+
task = await run_task.kiq()
89+
resp = await task.wait_result(0.1, timeout=1)
90+
with pytest.raises(ValueError):
91+
resp.raise_for_error()
92+
93+
await asyncio.sleep(0.2)
94+
resp = await task.wait_result(timeout=1)
95+
with pytest.raises(ValueError):
96+
resp.raise_for_error()
97+
98+
assert done
99+
100+
101+
@pytest.mark.anyio
102+
async def test_max_retries() -> None:
103+
"""Tests wait_result."""
104+
105+
broker = InMemoryBroker().with_middlewares(
106+
SimpleRetryMiddleware(
107+
no_result_on_retry=True,
108+
default_retry_label=True,
109+
),
110+
)
111+
runs = 0
112+
113+
@broker.task(max_retries=10)
114+
def run_task() -> str:
115+
nonlocal runs # noqa: WPS420
116+
117+
runs += 1
118+
raise ValueError(runs)
119+
120+
task = await run_task.kiq()
121+
resp = await task.wait_result(timeout=1)
122+
with pytest.raises(ValueError):
123+
resp.raise_for_error()
124+
125+
assert runs == 10
126+
assert str(resp.error) == str(runs)
127+
128+
129+
@pytest.mark.anyio
130+
async def test_no_retry() -> None:
131+
broker = InMemoryBroker().with_middlewares(
132+
SimpleRetryMiddleware(
133+
no_result_on_retry=True,
134+
default_retry_label=True,
135+
),
136+
)
137+
runs = 0
138+
139+
@broker.task(retry_on_error=False, max_retries=10)
140+
def run_task() -> str:
141+
nonlocal runs # noqa: WPS420
142+
143+
runs += 1
144+
raise ValueError(runs)
145+
146+
task = await run_task.kiq()
147+
resp = await task.wait_result(timeout=1)
148+
with pytest.raises(ValueError):
149+
resp.raise_for_error()
150+
151+
assert runs == 1
152+
assert str(resp.error) == str(runs)

0 commit comments

Comments
 (0)