Skip to content

Commit f09afc4

Browse files
authored
feat: support for handling custom exceptions in middleware. (#476)
1 parent 0b53745 commit f09afc4

File tree

3 files changed

+126
-3
lines changed

3 files changed

+126
-3
lines changed

taskiq/middlewares/simple_retry_middleware.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from logging import getLogger
2-
from typing import Any
2+
from typing import Any, Iterable, Optional
33

44
from taskiq.abc.middleware import TaskiqMiddleware
55
from taskiq.exceptions import NoResultError
@@ -18,10 +18,12 @@ def __init__(
1818
default_retry_count: int = 3,
1919
default_retry_label: bool = False,
2020
no_result_on_retry: bool = True,
21+
types_of_exceptions: Optional[Iterable[type[BaseException]]] = None,
2122
) -> None:
2223
self.default_retry_count = default_retry_count
2324
self.default_retry_label = default_retry_label
2425
self.no_result_on_retry = no_result_on_retry
26+
self.types_of_exceptions = types_of_exceptions
2527

2628
async def on_error(
2729
self,
@@ -42,6 +44,12 @@ async def on_error(
4244
:param result: execution result.
4345
:param exception: found exception.
4446
"""
47+
if self.types_of_exceptions is not None and not isinstance(
48+
exception,
49+
tuple(self.types_of_exceptions),
50+
):
51+
return
52+
4553
# Valid exception
4654
if isinstance(exception, NoResultError):
4755
return

taskiq/middlewares/smart_retry_middleware.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import random
33
from logging import getLogger
4-
from typing import Any, Optional
4+
from typing import Any, Iterable, Optional
55

66
from taskiq import ScheduleSource
77
from taskiq.abc.middleware import TaskiqMiddleware
@@ -35,6 +35,7 @@ def __init__(
3535
use_delay_exponent: bool = False,
3636
max_delay_exponent: float = 60,
3737
schedule_source: Optional[ScheduleSource] = None,
38+
types_of_exceptions: Optional[Iterable[type[BaseException]]] = None,
3839
) -> None:
3940
"""
4041
Initialize retry middleware.
@@ -48,6 +49,7 @@ def __init__(
4849
:param max_delay_exponent: Maximum allowed delay when using backoff.
4950
:param schedule_source: Schedule source to use for scheduling.
5051
If None, the default broker will be used.
52+
:param types_of_exceptions: Types of exceptions to retry from.
5153
"""
5254
super().__init__()
5355
self.default_retry_count = default_retry_count
@@ -58,6 +60,7 @@ def __init__(
5860
self.use_delay_exponent = use_delay_exponent
5961
self.max_delay_exponent = max_delay_exponent
6062
self.schedule_source = schedule_source
63+
self.types_of_exceptions = types_of_exceptions
6164

6265
if not isinstance(schedule_source, (ScheduleSource, type(None))):
6366
raise TypeError(
@@ -138,6 +141,12 @@ async def on_error(
138141
:param result: Execution result.
139142
:param exception: Caught exception.
140143
"""
144+
if self.types_of_exceptions is not None and not isinstance(
145+
exception,
146+
tuple(self.types_of_exceptions),
147+
):
148+
return
149+
141150
if isinstance(exception, NoResultError):
142151
return
143152

tests/middlewares/test_task_retry.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from taskiq import InMemoryBroker, SimpleRetryMiddleware
5+
from taskiq import InMemoryBroker, SimpleRetryMiddleware, SmartRetryMiddleware
66
from taskiq.exceptions import NoResultError
77

88

@@ -151,3 +151,109 @@ def run_task() -> str:
151151

152152
assert runs == 1
153153
assert str(resp.error) == str(runs)
154+
155+
156+
@pytest.mark.anyio
157+
async def test_retry_of_custom_exc_types_of_simple_middleware() -> None:
158+
# test that the passed error will be handled
159+
broker = InMemoryBroker().with_middlewares(
160+
SimpleRetryMiddleware(
161+
no_result_on_retry=True,
162+
default_retry_label=True,
163+
types_of_exceptions=(KeyError, ValueError),
164+
),
165+
)
166+
runs = 0
167+
168+
@broker.task(max_retries=10)
169+
def run_task() -> None:
170+
nonlocal runs
171+
172+
runs += 1
173+
174+
raise ValueError(runs)
175+
176+
task = await run_task.kiq()
177+
resp = await task.wait_result(timeout=1)
178+
with pytest.raises(ValueError):
179+
resp.raise_for_error()
180+
181+
assert runs == 10
182+
183+
# test that an untransmitted error will not be handled
184+
broker = InMemoryBroker().with_middlewares(
185+
SimpleRetryMiddleware(
186+
no_result_on_retry=True,
187+
default_retry_label=True,
188+
types_of_exceptions=(KeyError,),
189+
),
190+
)
191+
runs = 0
192+
193+
@broker.task(max_retries=10)
194+
def run_task2() -> None:
195+
nonlocal runs
196+
197+
runs += 1
198+
199+
raise ValueError(runs)
200+
201+
task = await run_task2.kiq()
202+
resp = await task.wait_result(timeout=1)
203+
with pytest.raises(ValueError):
204+
resp.raise_for_error()
205+
206+
assert runs == 1
207+
208+
209+
@pytest.mark.anyio
210+
async def test_retry_of_custom_exc_types_of_smart_middleware() -> None:
211+
# test that the passed error will be handled
212+
broker = InMemoryBroker().with_middlewares(
213+
SmartRetryMiddleware(
214+
no_result_on_retry=True,
215+
default_retry_label=True,
216+
types_of_exceptions=(KeyError, ValueError),
217+
),
218+
)
219+
runs = 0
220+
221+
@broker.task(max_retries=10)
222+
def run_task() -> None:
223+
nonlocal runs
224+
225+
runs += 1
226+
227+
raise ValueError(runs)
228+
229+
task = await run_task.kiq()
230+
resp = await task.wait_result(timeout=1)
231+
with pytest.raises(ValueError):
232+
resp.raise_for_error()
233+
234+
assert runs == 10
235+
236+
# test that an untransmitted error will not be handled
237+
broker = InMemoryBroker().with_middlewares(
238+
SmartRetryMiddleware(
239+
no_result_on_retry=True,
240+
default_retry_label=True,
241+
types_of_exceptions=(KeyError,),
242+
),
243+
)
244+
runs = 0
245+
246+
@broker.task(max_retries=10)
247+
def run_task2() -> None:
248+
nonlocal runs
249+
250+
runs += 1
251+
252+
raise ValueError(runs)
253+
254+
task = await run_task2.kiq()
255+
resp = await task.wait_result(timeout=1)
256+
with pytest.raises(ValueError):
257+
resp.raise_for_error()
258+
259+
assert runs == 1

0 commit comments

Comments
 (0)