Skip to content

Commit 8789d9c

Browse files
authored
Added ack config. (#249)
1 parent fb27ddf commit 8789d9c

File tree

5 files changed

+52
-3
lines changed

5 files changed

+52
-3
lines changed

taskiq/acks.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1+
import enum
12
from typing import Awaitable, Callable, Union
23

34
from pydantic import BaseModel
45

56

7+
@enum.unique
8+
class AcknowledgeType(str, enum.Enum):
9+
"""Enum with possible acknowledge times."""
10+
11+
# The message is acknowledged right when it's received,
12+
# before it's executed.
13+
WHEN_RECEIVED = "when_received"
14+
# This option means that the message will be
15+
# acknowledged right after it's executed.
16+
WHEN_EXECUTED = "when_executed"
17+
# This option means that the message will be
18+
# acknowledged when the task will be saved
19+
# only after it's saved in the result backend.
20+
WHEN_SAVED = "when_saved"
21+
22+
623
class AckableMessage(BaseModel):
724
"""
825
Message that can be acknowledged.

taskiq/api/receiver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
from concurrent.futures import ThreadPoolExecutor
33
from logging import getLogger
4-
from typing import Type
4+
from typing import Optional, Type
55

66
from taskiq.abc.broker import AsyncBroker
7+
from taskiq.acks import AcknowledgeType
78
from taskiq.receiver.receiver import Receiver
89

910
logger = getLogger("taskiq.receiver")
@@ -18,6 +19,7 @@ async def run_receiver_task(
1819
max_prefetch: int = 0,
1920
propagate_exceptions: bool = True,
2021
run_startup: bool = False,
22+
ack_time: Optional[AcknowledgeType] = None,
2123
) -> None:
2224
"""
2325
Function to run receiver programmatically.
@@ -71,6 +73,7 @@ def on_exit(_: Receiver) -> None:
7173
max_prefetch=max_prefetch,
7274
propagate_exceptions=propagate_exceptions,
7375
on_exit=on_exit,
76+
ack_type=ack_time,
7477
)
7578
await receiver.listen()
7679
except asyncio.CancelledError:

taskiq/cli/worker/args.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass, field
33
from typing import List, Optional, Sequence, Tuple
44

5+
from taskiq.acks import AcknowledgeType
56
from taskiq.cli.common_args import LogLevel
67

78

@@ -41,6 +42,7 @@ class WorkerArgs:
4142
max_prefetch: int = 0
4243
no_propagate_errors: bool = False
4344
max_fails: int = -1
45+
ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED
4446

4547
@classmethod
4648
def from_cli(
@@ -187,6 +189,13 @@ def from_cli(
187189
default=-1,
188190
help="Maximum number of child process exits.",
189191
)
192+
parser.add_argument(
193+
"--ack-type",
194+
type=lambda value: AcknowledgeType(value.lower()),
195+
default=AcknowledgeType.WHEN_SAVED,
196+
choices=[ack_type.name.lower() for ack_type in AcknowledgeType],
197+
help="When to acknowledge message.",
198+
)
190199

191200
namespace = parser.parse_args(args)
192201
return WorkerArgs(**namespace.__dict__)

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
141141
max_async_tasks=args.max_async_tasks,
142142
max_prefetch=args.max_prefetch,
143143
propagate_exceptions=not args.no_propagate_errors,
144+
ack_type=args.ack_type,
144145
**receiver_kwargs, # type: ignore
145146
)
146147
loop.run_until_complete(receiver.listen())

taskiq/receiver/receiver.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from taskiq.abc.broker import AckableMessage, AsyncBroker
1212
from taskiq.abc.middleware import TaskiqMiddleware
13+
from taskiq.acks import AcknowledgeType
1314
from taskiq.context import Context
1415
from taskiq.exceptions import NoResultError
1516
from taskiq.message import TaskiqMessage
@@ -53,6 +54,7 @@ def __init__(
5354
max_prefetch: int = 0,
5455
propagate_exceptions: bool = True,
5556
run_starup: bool = True,
57+
ack_type: Optional[AcknowledgeType] = None,
5658
on_exit: Optional[Callable[["Receiver"], None]] = None,
5759
) -> None:
5860
self.broker = broker
@@ -64,6 +66,7 @@ def __init__(
6466
self.dependency_graphs: Dict[str, DependencyGraph] = {}
6567
self.propagate_exceptions = propagate_exceptions
6668
self.on_exit = on_exit
69+
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
6770
self.known_tasks: Set[str] = set()
6871
for task in self.broker.get_all_tasks().values():
6972
self._prepare_task(task.task_name, task.original_func)
@@ -131,13 +134,21 @@ async def callback( # noqa: C901, PLR0912
131134
taskiq_msg.task_id,
132135
)
133136

137+
if self.ack_time == AcknowledgeType.WHEN_RECEIVED and isinstance(
138+
message,
139+
AckableMessage,
140+
):
141+
await maybe_awaitable(message.ack())
142+
134143
result = await self.run_task(
135144
target=task.original_func,
136145
message=taskiq_msg,
137146
)
138147

139-
# If broker has an ability to ack messages.
140-
if isinstance(message, AckableMessage):
148+
if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance(
149+
message,
150+
AckableMessage,
151+
):
141152
await maybe_awaitable(message.ack())
142153

143154
for middleware in self.broker.middlewares:
@@ -147,9 +158,11 @@ async def callback( # noqa: C901, PLR0912
147158
try:
148159
if not isinstance(result.error, NoResultError):
149160
await self.broker.result_backend.set_result(taskiq_msg.task_id, result)
161+
150162
for middleware in self.broker.middlewares:
151163
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
152164
await maybe_awaitable(middleware.post_save(taskiq_msg, result))
165+
153166
except Exception as exc:
154167
logger.exception(
155168
"Can't set result in result backend. Cause: %s",
@@ -159,6 +172,12 @@ async def callback( # noqa: C901, PLR0912
159172
if raise_err:
160173
raise exc
161174

175+
if self.ack_time == AcknowledgeType.WHEN_SAVED and isinstance(
176+
message,
177+
AckableMessage,
178+
):
179+
await maybe_awaitable(message.ack())
180+
162181
async def run_task( # noqa: C901, PLR0912, PLR0915
163182
self,
164183
target: Callable[..., Any],

0 commit comments

Comments
 (0)