Skip to content

Commit 1625972

Browse files
authored
Added acknowledgable messages. (#144)
1 parent e56cf1d commit 1625972

File tree

9 files changed

+269
-24
lines changed

9 files changed

+269
-24
lines changed

docs/examples/extending/broker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import AsyncGenerator
1+
from typing import AsyncGenerator, Union
22

3-
from taskiq import AsyncBroker, BrokerMessage
3+
from taskiq import AckableMessage, AsyncBroker, BrokerMessage
44

55

66
class MyBroker(AsyncBroker):
@@ -23,7 +23,7 @@ async def kick(self, message: BrokerMessage) -> None:
2323
# Send a message.message.
2424
pass
2525

26-
async def listen(self) -> AsyncGenerator[bytes, None]:
26+
async def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]:
2727
while True:
2828
# Get new message.
2929
new_message: bytes = ... # type: ignore

docs/extending-taskiq/broker.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,28 @@ As a broker developer, please send only raw bytes from the `message` field of a
2323
:::
2424

2525

26-
The `listen` method should yield raw bytes that were sent over the network.
26+
## Acknowledgement
27+
28+
The `listen` method should yield raw bytes of a message.
29+
But if your broker supports acking or rejecting messages, the broker should return `taskiq.AckableMessage`
30+
with required fields.
31+
32+
For example:
33+
34+
```python
35+
36+
async def listen(self) -> AsyncGenerator[AckableMessage, None]:
37+
for message in self.my_channel:
38+
yield AckableMessage(
39+
data=message.bytes,
40+
# Ack is a function that takes no parameters.
41+
# So you either set here method of a message,
42+
# or you can make a closure.
43+
ack=message.ack
44+
# Can be set to None if broker doesn't support it.
45+
reject=message.reject
46+
)
47+
```
2748

2849
## Conventions
2950

taskiq/__init__.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,22 @@
66
from taskiq.abc.middleware import TaskiqMiddleware
77
from taskiq.abc.result_backend import AsyncResultBackend
88
from taskiq.abc.schedule_source import ScheduleSource
9+
from taskiq.acks import AckableMessage
910
from taskiq.brokers.inmemory_broker import InMemoryBroker
1011
from taskiq.brokers.shared_broker import async_shared_broker
1112
from taskiq.brokers.zmq_broker import ZeroMQBroker
1213
from taskiq.context import Context
1314
from taskiq.events import TaskiqEvents
14-
from taskiq.exceptions import TaskiqError
15+
from taskiq.exceptions import (
16+
NoResultError,
17+
RejectError,
18+
ResultGetError,
19+
ResultIsReadyError,
20+
SecurityError,
21+
SendTaskError,
22+
TaskiqError,
23+
TaskiqResultTimeoutError,
24+
)
1525
from taskiq.funcs import gather
1626
from taskiq.message import BrokerMessage, TaskiqMessage
1727
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
@@ -26,23 +36,31 @@
2636
"Context",
2737
"AsyncBroker",
2838
"TaskiqError",
39+
"RejectError",
2940
"TaskiqState",
3041
"TaskiqResult",
3142
"ZeroMQBroker",
3243
"TaskiqEvents",
44+
"SecurityError",
3345
"TaskiqMessage",
3446
"BrokerMessage",
47+
"ResultGetError",
3548
"ScheduledTask",
3649
"TaskiqDepends",
50+
"NoResultError",
51+
"SendTaskError",
52+
"AckableMessage",
3753
"InMemoryBroker",
3854
"ScheduleSource",
3955
"TaskiqScheduler",
4056
"TaskiqFormatter",
4157
"AsyncTaskiqTask",
4258
"TaskiqMiddleware",
59+
"ResultIsReadyError",
4360
"AsyncResultBackend",
4461
"async_shared_broker",
45-
"AsyncTaskiqDecoratedTask",
46-
"SimpleRetryMiddleware",
4762
"PrometheusMiddleware",
63+
"SimpleRetryMiddleware",
64+
"AsyncTaskiqDecoratedTask",
65+
"TaskiqResultTimeoutError",
4866
]

taskiq/abc/broker.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing_extensions import ParamSpec, Self, TypeAlias
2525

2626
from taskiq.abc.middleware import TaskiqMiddleware
27+
from taskiq.acks import AckableMessage
2728
from taskiq.decor import AsyncTaskiqDecoratedTask
2829
from taskiq.events import TaskiqEvents
2930
from taskiq.formatters.json_formatter import JSONFormatter
@@ -185,13 +186,19 @@ async def kick(
185186
"""
186187

187188
@abstractmethod
188-
def listen(self) -> AsyncGenerator[bytes, None]:
189+
def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]:
189190
"""
190191
This function listens to new messages and yields them.
191192
192193
This it the main point for workers.
193194
This function is used to get new tasks from the network.
194195
196+
If your broker support acknowledgement, then you
197+
should wrap your message in AckableMessage dataclass.
198+
199+
If your messages was wrapped in AckableMessage dataclass,
200+
taskiq will call ack when finish processing message.
201+
195202
:yield: incoming messages.
196203
:return: nothing.
197204
"""

taskiq/acks.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import dataclasses
2+
from typing import Awaitable, Callable, Optional, Union
3+
4+
5+
@dataclasses.dataclass
6+
class AckableMessage:
7+
"""
8+
Message that can be acknowledged.
9+
10+
If your broker support message acknowledgement,
11+
please return this type of message, so we'll be
12+
able to mark this message as acknowledged after
13+
the function will be executed.
14+
15+
It adds more reliability to brokers and system
16+
as a whole.
17+
"""
18+
19+
data: bytes
20+
ack: Callable[[], Union[None, Awaitable[None]]]
21+
reject: Optional[Callable[[], Union[None, Awaitable[None]]]] = None

taskiq/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,7 @@ class SecurityError(TaskiqError):
3636

3737
class NoResultError(TaskiqError):
3838
"""Error if user does not want to set result."""
39+
40+
41+
class RejectError(TaskiqError):
42+
"""Error is thrown if message should be rejected."""

taskiq/receiver/receiver.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from concurrent.futures import Executor
44
from logging import getLogger
55
from time import time
6-
from typing import Any, Callable, Dict, Optional, Set, get_type_hints
6+
from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints
77

88
import anyio
99
from taskiq_dependencies import DependencyGraph
1010

11-
from taskiq.abc.broker import AsyncBroker
11+
from taskiq.abc.broker import AckableMessage, AsyncBroker
1212
from taskiq.abc.middleware import TaskiqMiddleware
1313
from taskiq.context import Context
14-
from taskiq.exceptions import NoResultError
14+
from taskiq.exceptions import NoResultError, RejectError
1515
from taskiq.message import TaskiqMessage
1616
from taskiq.receiver.params_parser import parse_params
1717
from taskiq.result import TaskiqResult
@@ -69,9 +69,9 @@ def __init__( # noqa: WPS211
6969
)
7070
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
7171

72-
async def callback( # noqa: C901, WPS213
72+
async def callback( # noqa: C901, WPS213, WPS217
7373
self,
74-
message: bytes,
74+
message: Union[bytes, AckableMessage],
7575
raise_err: bool = False,
7676
) -> None:
7777
"""
@@ -86,12 +86,16 @@ async def callback( # noqa: C901, WPS213
8686
:param raise_err: raise an error if cannot save result in
8787
result_backend.
8888
"""
89+
if isinstance(message, AckableMessage):
90+
message_data = message.data
91+
else:
92+
message_data = message
8993
try:
90-
taskiq_msg = self.broker.formatter.loads(message=message)
94+
taskiq_msg = self.broker.formatter.loads(message=message_data)
9195
except Exception as exc:
9296
logger.warning(
9397
"Cannot parse message: %s. Skipping execution.\n %s",
94-
message,
98+
message_data,
9599
exc,
96100
exc_info=True,
97101
)
@@ -124,9 +128,20 @@ async def callback( # noqa: C901, WPS213
124128
target=self.broker.available_tasks[taskiq_msg.task_name].original_func,
125129
message=taskiq_msg,
126130
)
131+
132+
# If broker has an ability to ack or reject messages.
133+
if isinstance(message, AckableMessage):
134+
# If we received an error for negative acknowledgement.
135+
if message.reject is not None and isinstance(result.error, RejectError):
136+
await maybe_awaitable(message.reject())
137+
# Otherwise we positively acknowledge the message.
138+
else:
139+
await maybe_awaitable(message.ack())
140+
127141
for middleware in self.broker.middlewares:
128142
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
129143
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
144+
130145
try:
131146
if not isinstance(result.error, NoResultError):
132147
await self.broker.result_backend.set_result(taskiq_msg.task_id, result)
@@ -255,13 +270,16 @@ async def listen(self) -> None: # pragma: no cover
255270
"""
256271
await self.broker.startup()
257272
logger.info("Listening started.")
258-
queue: asyncio.Queue[bytes] = asyncio.Queue()
273+
queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue()
259274

260275
async with anyio.create_task_group() as gr:
261276
gr.start_soon(self.prefetcher, queue)
262277
gr.start_soon(self.runner, queue)
263278

264-
async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
279+
async def prefetcher(
280+
self,
281+
queue: "asyncio.Queue[Union[bytes, AckableMessage]]",
282+
) -> None:
265283
"""
266284
Prefetch tasks data.
267285
@@ -280,7 +298,10 @@ async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
280298

281299
await queue.put(QUEUE_DONE)
282300

283-
async def runner(self, queue: "asyncio.Queue[bytes]") -> None:
301+
async def runner(
302+
self,
303+
queue: "asyncio.Queue[Union[bytes, AckableMessage]]",
304+
) -> None:
284305
"""
285306
Run tasks.
286307

taskiq/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import inspect
2-
from typing import Any, Coroutine, TypeVar, Union
2+
from typing import Any, Awaitable, Coroutine, TypeVar, Union
33

44
_T = TypeVar("_T") # noqa: WPS111
55

66

77
async def maybe_awaitable(
8-
possible_coroutine: "Union[_T, Coroutine[Any, Any, _T]]",
8+
possible_coroutine: "Union[_T, Coroutine[Any, Any, _T], Awaitable[_T]]",
99
) -> _T:
1010
"""
1111
Awaits coroutine if needed.

0 commit comments

Comments
 (0)