3
3
from concurrent .futures import Executor
4
4
from logging import getLogger
5
5
from time import time
6
- from typing import Any , Callable , Dict , Optional , get_type_hints
6
+ from typing import Any , Callable , Dict , Optional , Set , get_type_hints
7
7
8
+ import anyio
8
9
from taskiq_dependencies import DependencyGraph
9
10
10
11
from taskiq .abc .broker import AsyncBroker
17
18
from taskiq .utils import maybe_awaitable
18
19
19
20
logger = getLogger (__name__ )
21
+ QUEUE_DONE = b"-1"
20
22
21
23
22
24
def _run_sync (target : Callable [..., Any ], message : TaskiqMessage ) -> Any :
@@ -36,12 +38,13 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
36
38
class Receiver :
37
39
"""Class that uses as a callback handler."""
38
40
39
- def __init__ (
41
+ def __init__ ( # noqa: WPS211
40
42
self ,
41
43
broker : AsyncBroker ,
42
44
executor : Optional [Executor ] = None ,
43
45
validate_params : bool = True ,
44
46
max_async_tasks : "Optional[int]" = None ,
47
+ max_prefetch : int = 0 ,
45
48
) -> None :
46
49
self .broker = broker
47
50
self .executor = executor
@@ -61,6 +64,7 @@ def __init__(
61
64
"Setting unlimited number of async tasks "
62
65
+ "can result in undefined behavior" ,
63
66
)
67
+ self .sem_prefetch = asyncio .Semaphore (max_prefetch )
64
68
65
69
async def callback ( # noqa: C901, WPS213
66
70
self ,
@@ -239,7 +243,38 @@ async def listen(self) -> None: # pragma: no cover
239
243
"""
240
244
await self .broker .startup ()
241
245
logger .info ("Listening started." )
242
- tasks = set ()
246
+ queue : asyncio .Queue [bytes ] = asyncio .Queue ()
247
+
248
+ async with anyio .create_task_group () as gr :
249
+ gr .start_soon (self .prefetcher , queue )
250
+ gr .start_soon (self .runner , queue )
251
+
252
+ async def prefetcher (self , queue : "asyncio.Queue[Any]" ) -> None :
253
+ """
254
+ Prefetch tasks data.
255
+
256
+ :param queue: queue for prefetched data.
257
+ """
258
+ iterator = self .broker .listen ()
259
+
260
+ while True :
261
+ try :
262
+ await self .sem_prefetch .acquire ()
263
+ message = await iterator .__anext__ () # noqa: WPS609
264
+ await queue .put (message )
265
+
266
+ except StopAsyncIteration :
267
+ break
268
+
269
+ await queue .put (QUEUE_DONE )
270
+
271
+ async def runner (self , queue : "asyncio.Queue[bytes]" ) -> None :
272
+ """
273
+ Run tasks.
274
+
275
+ :param queue: queue with prefetched data.
276
+ """
277
+ tasks : Set [asyncio .Task [Any ]] = set ()
243
278
244
279
def task_cb (task : "asyncio.Task[Any]" ) -> None :
245
280
"""
@@ -255,11 +290,19 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
255
290
if self .sem is not None :
256
291
self .sem .release ()
257
292
258
- async for message in self . broker . listen () :
293
+ while True :
259
294
# Waits for semaphore to be released.
260
295
if self .sem is not None :
261
296
await self .sem .acquire ()
262
- task = asyncio .create_task (self .callback (message = message , raise_err = False ))
297
+
298
+ self .sem_prefetch .release ()
299
+ message = await queue .get ()
300
+ if message is QUEUE_DONE :
301
+ break
302
+
303
+ task = asyncio .create_task (
304
+ self .callback (message = message , raise_err = False ),
305
+ )
263
306
tasks .add (task )
264
307
265
308
# We want the task to remove itself from the set when it's done.
0 commit comments