Skip to content

Commit 8afb1c9

Browse files
authored
Added context methods to reject and requeue. (#152)
1 parent f2ed55b commit 8afb1c9

File tree

7 files changed

+93
-110
lines changed

7 files changed

+93
-110
lines changed

taskiq/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from taskiq.events import TaskiqEvents
1515
from taskiq.exceptions import (
1616
NoResultError,
17-
RejectError,
1817
ResultGetError,
1918
ResultIsReadyError,
2019
SecurityError,
@@ -45,7 +44,6 @@
4544
"Context",
4645
"AsyncBroker",
4746
"TaskiqError",
48-
"RejectError",
4947
"TaskiqState",
5048
"TaskiqResult",
5149
"ZeroMQBroker",

taskiq/acks.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import dataclasses
2-
from typing import Awaitable, Callable, Optional, Union
1+
from typing import Awaitable, Callable, Union
32

3+
from pydantic import BaseModel
44

5-
@dataclasses.dataclass
6-
class AckableMessage:
5+
6+
class AckableMessage(BaseModel):
77
"""
88
Message that can be acknowledged.
99
@@ -18,4 +18,3 @@ class AckableMessage:
1818

1919
data: bytes
2020
ack: Callable[[], Union[None, Awaitable[None]]]
21-
reject: Optional[Callable[[], Union[None, Awaitable[None]]]] = None

taskiq/context.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from copy import copy
12
from typing import TYPE_CHECKING
23

34
from taskiq.abc.broker import AsyncBroker
5+
from taskiq.exceptions import NoResultError, TaskRejectedError
46
from taskiq.message import TaskiqMessage
57

68
if TYPE_CHECKING: # pragma: no cover
@@ -15,3 +17,28 @@ def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
1517
self.broker = broker
1618
self.state: "TaskiqState" = None # type: ignore
1719
self.state = broker.state
20+
21+
async def requeue(self) -> None:
22+
"""
23+
Requeue task.
24+
25+
This fuction creates a task with
26+
the same message and sends it using
27+
current broker.
28+
29+
:raises NoResultError: to not store result for current task.
30+
"""
31+
message = copy(self.message)
32+
requeue_count = int(message.labels.get("X-Taskiq-requeue", 0))
33+
requeue_count += 1
34+
message.labels["X-Taskiq-requeue"] = str(requeue_count)
35+
await self.broker.kick(self.broker.formatter.dumps(self.message))
36+
raise NoResultError()
37+
38+
def reject(self) -> None:
39+
"""
40+
Raise reject error.
41+
42+
:raises TaskRejectedError: to reject current message.
43+
"""
44+
raise TaskRejectedError()

taskiq/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ class NoResultError(TaskiqError):
3838
"""Error if user does not want to set result."""
3939

4040

41-
class RejectError(TaskiqError):
42-
"""Error is thrown if message should be rejected."""
41+
class TaskRejectedError(TaskiqError):
42+
"""Task was rejected."""

taskiq/receiver/receiver.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from concurrent.futures import Executor
44
from logging import getLogger
55
from time import time
6-
from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints
6+
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints
77

88
import anyio
99
from taskiq_dependencies import DependencyGraph
1010

1111
from taskiq.abc.broker import AckableMessage, AsyncBroker
1212
from taskiq.abc.middleware import TaskiqMiddleware
1313
from taskiq.context import Context
14-
from taskiq.exceptions import NoResultError, RejectError
14+
from taskiq.exceptions import NoResultError
1515
from taskiq.message import TaskiqMessage
1616
from taskiq.receiver.params_parser import parse_params
1717
from taskiq.result import TaskiqResult
@@ -22,18 +22,23 @@
2222
QUEUE_DONE = b"-1"
2323

2424

25-
def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
25+
def _run_sync(
26+
target: Callable[..., Any],
27+
args: List[Any],
28+
kwargs: Dict[str, Any],
29+
) -> Any:
2630
"""
2731
Runs function synchronously.
2832
2933
We use this function, because
3034
we cannot pass kwargs in loop.run_with_executor().
3135
3236
:param target: function to execute.
33-
:param message: received message from broker.
37+
:param args: list of function's args.
38+
:param kwargs: dict of function's kwargs.
3439
:return: result of function's execution.
3540
"""
36-
return target(*message.args, **message.kwargs)
41+
return target(*args, **kwargs)
3742

3843

3944
class Receiver:
@@ -124,20 +129,16 @@ async def callback( # noqa: C901, WPS213, WPS217
124129
taskiq_msg.task_name,
125130
taskiq_msg.task_id,
126131
)
132+
133+
# If broker has an ability to ack messages.
134+
if isinstance(message, AckableMessage):
135+
await maybe_awaitable(message.ack())
136+
127137
result = await self.run_task(
128138
target=self.broker.available_tasks[taskiq_msg.task_name].original_func,
129139
message=taskiq_msg,
130140
)
131141

132-
# If broker has an ability to ack or reject messages.
133-
if isinstance(message, AckableMessage):
134-
# If we received an error for negative acknowledgement.
135-
if message.reject is not None and isinstance(result.error, RejectError):
136-
await maybe_awaitable(message.reject())
137-
# Otherwise we positively acknowledge the message.
138-
else:
139-
await maybe_awaitable(message.ack())
140-
141142
for middleware in self.broker.middlewares:
142143
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
143144
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
@@ -182,14 +183,18 @@ async def run_task( # noqa: C901, WPS210
182183
"""
183184
loop = asyncio.get_running_loop()
184185
returned = None
185-
found_exception = None
186+
found_exception: "Optional[BaseException]" = None
186187
signature = None
187188
if self.validate_params:
188189
signature = self.task_signatures.get(message.task_name)
189190
dependency_graph = self.dependency_graphs.get(message.task_name)
190191
parse_params(signature, self.task_hints.get(message.task_name) or {}, message)
191192

192193
dep_ctx = None
194+
# Kwargs are defined in another variable,
195+
# because we want to update them with
196+
# kwargs resolved by dependency injector.
197+
kwargs = {}
193198
if dependency_graph:
194199
# Create a context for dependency resolving.
195200
broker_ctx = self.broker.custom_dependency_context
@@ -201,25 +206,34 @@ async def run_task( # noqa: C901, WPS210
201206
)
202207
dep_ctx = dependency_graph.async_ctx(broker_ctx)
203208
# Resolve all function's dependencies.
204-
dep_kwargs = await dep_ctx.resolve_kwargs()
205-
for key, val in dep_kwargs.items():
206-
if key not in message.kwargs:
207-
message.kwargs[key] = val
209+
kwargs = await dep_ctx.resolve_kwargs()
210+
211+
# We udpate kwargs with kwargs from network.
212+
kwargs.update(message.kwargs)
213+
208214
# Start a timer.
209215
start_time = time()
210216
try:
211-
# If the function is a coroutine we await it.
217+
# If the function is a coroutine, we await it.
212218
if asyncio.iscoroutinefunction(target):
213-
returned = await target(*message.args, **message.kwargs)
219+
returned = await target(*message.args, **kwargs)
214220
else:
215-
# If this is a synchronous function we
221+
# If this is a synchronous function, we
216222
# run it in executor.
217223
returned = await loop.run_in_executor(
218224
self.executor,
219225
_run_sync,
220226
target,
221-
message,
227+
message.args,
228+
kwargs,
222229
)
230+
except NoResultError as no_res_exc:
231+
found_exception = no_res_exc
232+
logger.warning(
233+
"Task %s with id %s skipped setting result.",
234+
message.task_name,
235+
message.task_id,
236+
)
223237
except BaseException as exc: # noqa: WPS424
224238
found_exception = exc
225239
logger.error(

tests/receiver/test_receiver.py

Lines changed: 1 addition & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from taskiq.abc.broker import AckableMessage, AsyncBroker
99
from taskiq.abc.middleware import TaskiqMiddleware
1010
from taskiq.brokers.inmemory_broker import InMemoryBroker
11-
from taskiq.exceptions import NoResultError, RejectError, TaskiqResultTimeoutError
11+
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
1212
from taskiq.message import TaskiqMessage
1313
from taskiq.receiver import Receiver
1414
from taskiq.result import TaskiqResult
@@ -260,83 +260,6 @@ async def ack_callback() -> None:
260260
assert acked
261261

262262

263-
@pytest.mark.anyio
264-
async def test_callback_success_reject() -> None:
265-
"""
266-
Test that if reject error is thrown,
267-
broker would reject a message.
268-
"""
269-
broker = InMemoryBroker()
270-
rejected = False
271-
272-
@broker.task
273-
async def my_task() -> None:
274-
raise RejectError()
275-
276-
def reject_callback() -> None:
277-
nonlocal rejected
278-
rejected = True
279-
280-
receiver = get_receiver(broker)
281-
282-
broker_message = broker.formatter.dumps(
283-
TaskiqMessage(
284-
task_id="task_id",
285-
task_name=my_task.task_name,
286-
labels={},
287-
args=[],
288-
kwargs={},
289-
),
290-
)
291-
292-
await receiver.callback(
293-
AckableMessage(
294-
data=broker_message.message,
295-
ack=lambda: None,
296-
reject=reject_callback,
297-
),
298-
)
299-
assert rejected
300-
301-
302-
@pytest.mark.anyio
303-
async def test_callback_no_reject_func() -> None:
304-
"""
305-
Test that if broker doesn't support rejects,
306-
it acks message instead.
307-
"""
308-
broker = InMemoryBroker()
309-
acked = False
310-
311-
@broker.task
312-
async def my_task() -> None:
313-
raise RejectError()
314-
315-
def ack_callback() -> None:
316-
nonlocal acked
317-
acked = True
318-
319-
receiver = get_receiver(broker)
320-
321-
broker_message = broker.formatter.dumps(
322-
TaskiqMessage(
323-
task_id="task_id",
324-
task_name=my_task.task_name,
325-
labels={},
326-
args=[],
327-
kwargs={},
328-
),
329-
)
330-
331-
await receiver.callback(
332-
AckableMessage(
333-
data=broker_message.message,
334-
ack=ack_callback,
335-
),
336-
)
337-
assert acked
338-
339-
340263
@pytest.mark.anyio
341264
async def test_callback_wrong_format() -> None:
342265
"""Test that wrong format of a message won't thow an error."""

tests/test_requeue.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
from taskiq import Context, InMemoryBroker, TaskiqDepends
4+
5+
6+
@pytest.mark.anyio
7+
async def test_requeue() -> None:
8+
broker = InMemoryBroker()
9+
10+
runs_count = 0
11+
12+
@broker.task
13+
async def task(context: Context = TaskiqDepends()) -> None:
14+
nonlocal runs_count
15+
runs_count += 1
16+
if runs_count < 2:
17+
await context.requeue()
18+
19+
kicked = await task.kiq()
20+
await kicked.wait_result()
21+
22+
assert runs_count == 2

0 commit comments

Comments
 (0)