33from concurrent .futures import Executor
44from logging import getLogger
55from time import time
6- from typing import Any , Callable , Dict , Optional , Set , get_type_hints
6+ from typing import Any , Callable , Dict , Optional , Set , Union , get_type_hints
77
88import anyio
99from taskiq_dependencies import DependencyGraph
1010
11- from taskiq .abc .broker import AsyncBroker
11+ from taskiq .abc .broker import AckableMessage , AsyncBroker
1212from taskiq .abc .middleware import TaskiqMiddleware
1313from taskiq .context import Context
14- from taskiq .exceptions import NoResultError
14+ from taskiq .exceptions import NoResultError , RejectError
1515from taskiq .message import TaskiqMessage
1616from taskiq .receiver .params_parser import parse_params
1717from taskiq .result import TaskiqResult
@@ -69,9 +69,9 @@ def __init__( # noqa: WPS211
6969 )
7070 self .sem_prefetch = asyncio .Semaphore (max_prefetch )
7171
72- async def callback ( # noqa: C901, WPS213
72+ async def callback ( # noqa: C901, WPS213, WPS217
7373 self ,
74- message : bytes ,
74+ message : Union [ bytes , AckableMessage ] ,
7575 raise_err : bool = False ,
7676 ) -> None :
7777 """
@@ -86,12 +86,16 @@ async def callback( # noqa: C901, WPS213
8686 :param raise_err: raise an error if cannot save result in
8787 result_backend.
8888 """
89+ if isinstance (message , AckableMessage ):
90+ message_data = message .data
91+ else :
92+ message_data = message
8993 try :
90- taskiq_msg = self .broker .formatter .loads (message = message )
94+ taskiq_msg = self .broker .formatter .loads (message = message_data )
9195 except Exception as exc :
9296 logger .warning (
9397 "Cannot parse message: %s. Skipping execution.\n %s" ,
94- message ,
98+ message_data ,
9599 exc ,
96100 exc_info = True ,
97101 )
@@ -124,9 +128,20 @@ async def callback( # noqa: C901, WPS213
124128 target = self .broker .available_tasks [taskiq_msg .task_name ].original_func ,
125129 message = taskiq_msg ,
126130 )
131+
132+ # If broker has an ability to ack or reject messages.
133+ if isinstance (message , AckableMessage ):
134+ # If we received an error for negative acknowledgement.
135+ if message .reject is not None and isinstance (result .error , RejectError ):
136+ await maybe_awaitable (message .reject ())
137+ # Otherwise we positively acknowledge the message.
138+ else :
139+ await maybe_awaitable (message .ack ())
140+
127141 for middleware in self .broker .middlewares :
128142 if middleware .__class__ .post_execute != TaskiqMiddleware .post_execute :
129143 await maybe_awaitable (middleware .post_execute (taskiq_msg , result ))
144+
130145 try :
131146 if not isinstance (result .error , NoResultError ):
132147 await self .broker .result_backend .set_result (taskiq_msg .task_id , result )
@@ -255,13 +270,16 @@ async def listen(self) -> None: # pragma: no cover
255270 """
256271 await self .broker .startup ()
257272 logger .info ("Listening started." )
258- queue : asyncio .Queue [bytes ] = asyncio .Queue ()
273+ queue : " asyncio.Queue[Union[ bytes, AckableMessage]]" = asyncio .Queue ()
259274
260275 async with anyio .create_task_group () as gr :
261276 gr .start_soon (self .prefetcher , queue )
262277 gr .start_soon (self .runner , queue )
263278
264- async def prefetcher (self , queue : "asyncio.Queue[Any]" ) -> None :
279+ async def prefetcher (
280+ self ,
281+ queue : "asyncio.Queue[Union[bytes, AckableMessage]]" ,
282+ ) -> None :
265283 """
266284 Prefetch tasks data.
267285
@@ -280,7 +298,10 @@ async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
280298
281299 await queue .put (QUEUE_DONE )
282300
283- async def runner (self , queue : "asyncio.Queue[bytes]" ) -> None :
301+ async def runner (
302+ self ,
303+ queue : "asyncio.Queue[Union[bytes, AckableMessage]]" ,
304+ ) -> None :
284305 """
285306 Run tasks.
286307
0 commit comments