Skip to content

Commit 8f74c63

Browse files
authored
Added max-tasks-per-child parameter. (#314)
1 parent ed37be4 commit 8f74c63

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

taskiq/cli/worker/args.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class WorkerArgs:
4343
no_propagate_errors: bool = False
4444
max_fails: int = -1
4545
ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED
46+
max_tasks_per_child: Optional[int] = None
47+
wait_tasks_timeout: Optional[float] = None
4648

4749
@classmethod
4850
def from_cli(
@@ -197,6 +199,19 @@ def from_cli(
197199
choices=[ack_type.name.lower() for ack_type in AcknowledgeType],
198200
help="When to acknowledge message.",
199201
)
202+
parser.add_argument(
203+
"--max-tasks-per-child",
204+
type=int,
205+
default=None,
206+
help="Maximum number of tasks to execute per child process.",
207+
)
208+
parser.add_argument(
209+
"--wait-tasks-timeout",
210+
type=float,
211+
default=None,
212+
help="Maximum time to wait for all current tasks "
213+
"to finish before exiting.",
214+
)
200215

201216
namespace = parser.parse_args(args)
202217
# If there are any patterns specified, remove default.

taskiq/cli/worker/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
140140
max_prefetch=args.max_prefetch,
141141
propagate_exceptions=not args.no_propagate_errors,
142142
ack_type=args.ack_type,
143+
max_tasks_to_execute=args.max_tasks_per_child,
144+
wait_tasks_timeout=args.wait_tasks_timeout,
143145
**receiver_kwargs, # type: ignore
144146
)
145147
loop.run_until_complete(receiver.listen())

taskiq/receiver/receiver.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(
5656
run_starup: bool = True,
5757
ack_type: Optional[AcknowledgeType] = None,
5858
on_exit: Optional[Callable[["Receiver"], None]] = None,
59+
max_tasks_to_execute: Optional[int] = None,
60+
wait_tasks_timeout: Optional[float] = None,
5961
) -> None:
6062
self.broker = broker
6163
self.executor = executor
@@ -68,6 +70,8 @@ def __init__(
6870
self.on_exit = on_exit
6971
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
7072
self.known_tasks: Set[str] = set()
73+
self.max_tasks_to_execute = max_tasks_to_execute
74+
self.wait_tasks_timeout = wait_tasks_timeout
7175
for task in self.broker.get_all_tasks().values():
7276
self._prepare_task(task.task_name, task.original_func)
7377
self.sem: "Optional[asyncio.Semaphore]" = None
@@ -342,12 +346,20 @@ async def prefetcher(
342346
343347
:param queue: queue for prefetched data.
344348
"""
349+
fetched_tasks: int = 0
345350
iterator = self.broker.listen()
346351

347352
while True:
348353
try:
349354
await self.sem_prefetch.acquire()
355+
if (
356+
self.max_tasks_to_execute
357+
and fetched_tasks >= self.max_tasks_to_execute
358+
):
359+
logger.info("Max number of tasks executed.")
360+
break
350361
message = await iterator.__anext__()
362+
fetched_tasks += 1
351363
await queue.put(message)
352364
except asyncio.CancelledError:
353365
break
@@ -389,6 +401,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
389401
self.sem_prefetch.release()
390402
message = await queue.get()
391403
if message is QUEUE_DONE:
404+
logger.info("Waiting for running tasks to complete.")
405+
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
392406
break
393407

394408
task = asyncio.create_task(

0 commit comments

Comments
 (0)