Skip to content

Commit e67dab3

Browse files
Sobes76rusAnton
andauthored
feat: receiver max_prefetch argument (#127)
Co-authored-by: Anton <[email protected]>
1 parent 72830e9 commit e67dab3

File tree

3 files changed

+57
-5
lines changed

3 files changed

+57
-5
lines changed

taskiq/cli/worker/args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class WorkerArgs:
4040
max_async_tasks: int = 100
4141
receiver: str = "taskiq.receiver:Receiver"
4242
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)
43+
max_prefetch: int = 0
4344

4445
@classmethod
4546
def from_cli( # noqa: WPS213
@@ -168,6 +169,13 @@ def from_cli( # noqa: WPS213
168169
default=100,
169170
help="Maximum simultaneous async tasks per worker process. ",
170171
)
172+
parser.add_argument(
173+
"--max-prefetch",
174+
type=int,
175+
dest="max_prefetch",
176+
default=0,
177+
help="Maximum prefetched tasks per worker process. ",
178+
)
171179

172180
namespace = parser.parse_args(args)
173181
return WorkerArgs(**namespace.__dict__)

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
132132
executor=pool,
133133
validate_params=not args.no_parse,
134134
max_async_tasks=args.max_async_tasks,
135+
max_prefetch=args.max_prefetch,
135136
**receiver_args,
136137
)
137138
loop.run_until_complete(receiver.listen())

taskiq/receiver/receiver.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from concurrent.futures import Executor
44
from logging import getLogger
55
from time import time
6-
from typing import Any, Callable, Dict, Optional, get_type_hints
6+
from typing import Any, Callable, Dict, Optional, Set, get_type_hints
77

8+
import anyio
89
from taskiq_dependencies import DependencyGraph
910

1011
from taskiq.abc.broker import AsyncBroker
@@ -17,6 +18,7 @@
1718
from taskiq.utils import maybe_awaitable
1819

1920
logger = getLogger(__name__)
21+
QUEUE_DONE = b"-1"
2022

2123

2224
def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
@@ -36,12 +38,13 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
3638
class Receiver:
3739
"""Class that uses as a callback handler."""
3840

39-
def __init__(
41+
def __init__( # noqa: WPS211
4042
self,
4143
broker: AsyncBroker,
4244
executor: Optional[Executor] = None,
4345
validate_params: bool = True,
4446
max_async_tasks: "Optional[int]" = None,
47+
max_prefetch: int = 0,
4548
) -> None:
4649
self.broker = broker
4750
self.executor = executor
@@ -61,6 +64,7 @@ def __init__(
6164
"Setting unlimited number of async tasks "
6265
+ "can result in undefined behavior",
6366
)
67+
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
6468

6569
async def callback( # noqa: C901, WPS213
6670
self,
@@ -239,7 +243,38 @@ async def listen(self) -> None: # pragma: no cover
239243
"""
240244
await self.broker.startup()
241245
logger.info("Listening started.")
242-
tasks = set()
246+
queue: asyncio.Queue[bytes] = asyncio.Queue()
247+
248+
async with anyio.create_task_group() as gr:
249+
gr.start_soon(self.prefetcher, queue)
250+
gr.start_soon(self.runner, queue)
251+
252+
async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
253+
"""
254+
Prefetch tasks data.
255+
256+
:param queue: queue for prefetched data.
257+
"""
258+
iterator = self.broker.listen()
259+
260+
while True:
261+
try:
262+
await self.sem_prefetch.acquire()
263+
message = await iterator.__anext__() # noqa: WPS609
264+
await queue.put(message)
265+
266+
except StopAsyncIteration:
267+
break
268+
269+
await queue.put(QUEUE_DONE)
270+
271+
async def runner(self, queue: "asyncio.Queue[bytes]") -> None:
272+
"""
273+
Run tasks.
274+
275+
:param queue: queue with prefetched data.
276+
"""
277+
tasks: Set[asyncio.Task[Any]] = set()
243278

244279
def task_cb(task: "asyncio.Task[Any]") -> None:
245280
"""
@@ -255,11 +290,19 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
255290
if self.sem is not None:
256291
self.sem.release()
257292

258-
async for message in self.broker.listen():
293+
while True:
259294
# Waits for semaphore to be released.
260295
if self.sem is not None:
261296
await self.sem.acquire()
262-
task = asyncio.create_task(self.callback(message=message, raise_err=False))
297+
298+
self.sem_prefetch.release()
299+
message = await queue.get()
300+
if message is QUEUE_DONE:
301+
break
302+
303+
task = asyncio.create_task(
304+
self.callback(message=message, raise_err=False),
305+
)
263306
tasks.add(task)
264307

265308
# We want the task to remove itself from the set when it's done.

0 commit comments

Comments
 (0)