taskiq/api/receiver.py (16 changes: 12 additions & 4 deletions)
@@ -1,5 +1,5 @@
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
 from logging import getLogger
 from typing import Optional, Type
 
@@ -13,13 +13,14 @@
 async def run_receiver_task(
     broker: AsyncBroker,
     receiver_cls: Type[Receiver] = Receiver,
-    sync_workers: int = 4,
+    sync_workers: Optional[int] = None,
     validate_params: bool = True,
     max_async_tasks: int = 100,
     max_prefetch: int = 0,
     propagate_exceptions: bool = True,
     run_startup: bool = False,
     ack_time: Optional[AcknowledgeType] = None,
+    use_process_pool: bool = False,
 ) -> None:
     """
     Function to run receiver programmatically.
@@ -39,13 +40,15 @@ async def run_receiver_task(
     :param broker: current broker instance.
     :param receiver_cls: receiver class to use.
-    :param sync_workers: number of threads of a threadpool that runs sync tasks.
+    :param sync_workers: number of threads of a threadpool
+        or processes in processpool that runs sync tasks.
     :param validate_params: whether to validate params or not.
     :param max_async_tasks: maximum number of simultaneous async tasks.
     :param max_prefetch: maximum number of tasks to prefetch.
     :param propagate_exceptions: whether to propagate exceptions in generators or not.
     :param run_startup: whether to run startup function or not.
     :param ack_time: acknowledge type to use.
+    :param use_process_pool: whether to use process pool or threadpool.
     :raises asyncio.CancelledError: if the task was cancelled.
     """
     finish_event = asyncio.Event()
@@ -62,7 +65,12 @@ def on_exit(_: Receiver) -> None:
         finish_event.set()
         raise asyncio.CancelledError
 
-    with ThreadPoolExecutor(max_workers=sync_workers) as executor:
+    executor: Executor
+    if use_process_pool:
+        executor = ProcessPoolExecutor(max_workers=sync_workers)
+    else:
+        executor = ThreadPoolExecutor(max_workers=sync_workers)
+    with executor as executor:
         broker.is_worker_process = True
         while True:
             try:
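For context, this is how the new flag is used programmatically. A minimal sketch, assuming a broker defined in a hypothetical my_module; everything around run_receiver_task here is illustrative, not part of the diff:

import asyncio

from taskiq.api import run_receiver_task

from my_module import broker  # hypothetical module exposing an AsyncBroker


async def main() -> None:
    # use_process_pool=True makes the receiver run sync tasks in a
    # ProcessPoolExecutor; sync_workers now defaults to None, letting
    # the chosen pool pick its own default size.
    receiver_task = asyncio.create_task(
        run_receiver_task(broker, sync_workers=4, use_process_pool=True),
    )
    try:
        await asyncio.sleep(60.0)  # serve tasks for a while
    finally:
        receiver_task.cancel()


if __name__ == "__main__":
    asyncio.run(main())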
taskiq/cli/worker/args.py (18 changes: 16 additions & 2 deletions)
@@ -32,6 +32,7 @@ class WorkerArgs:
     log_level: LogLevel = LogLevel.INFO
     workers: int = 2
     max_threadpool_threads: int = 10
+    max_process_pool_processes: Optional[int] = None
     no_parse: bool = False
     shutdown_timeout: float = 5
     reload: bool = False
@@ -46,6 +47,7 @@ class WorkerArgs:
     max_tasks_per_child: Optional[int] = None
     wait_tasks_timeout: Optional[float] = None
     hardkill_count: int = 3
+    use_process_pool: bool = False
 
     @classmethod
     def from_cli(
@@ -210,8 +212,7 @@ def from_cli(
             "--wait-tasks-timeout",
             type=float,
             default=None,
-            help="Maximum time to wait for all current tasks "
-            "to finish before exiting.",
+            help="Maximum time to wait for all current tasks to finish before exiting.",
         )
         parser.add_argument(
             "--hardkill-count",
@@ -220,6 +221,19 @@
             help="Number of termination signals to the main "
             "process before performing a hardkill.",
         )
+        parser.add_argument(
+            "--use-process-pool",
+            action="store_true",
+            dest="use_process_pool",
+            help="Use process pool instead of thread pool for sync tasks.",
+        )
+        parser.add_argument(
+            "--max-process-pool-processes",
+            type=int,
+            dest="max_process_pool_processes",
+            default=None,
+            help="Maximum number of processes in process pool.",
+        )
 
         namespace = parser.parse_args(args)
         # If there are any patterns specified, remove default.
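With these arguments in place, the worker can be started with the new flags from the command line, e.g. (the broker path my_module:broker is illustrative):

taskiq worker my_module:broker --use-process-pool --max-process-pool-processes 4

When --max-process-pool-processes is omitted, max_workers stays None and the process pool falls back to its own default size.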
taskiq/cli/worker/process_manager.py (4 changes: 2 additions & 2 deletions)
@@ -83,7 +83,7 @@ def handle(
             target=worker_func,
             kwargs={"args": args},
             name=f"worker-{self.worker_num}",
-            daemon=True,
+            daemon=False,
         )
         new_process.start()
         logger.info(f"Process {new_process.name} restarted with pid {new_process.pid}")
@@ -193,7 +193,7 @@ def prepare_workers(self) -> None:
                 target=self.worker_function,
                 kwargs={"args": self.args},
                 name=f"worker-{process}",
-                daemon=True,
+                daemon=False,
             )
             work_proc.start()
             logger.info(
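The daemon=True to daemon=False switch is what makes the process pool possible: a daemonic multiprocessing.Process is not allowed to spawn children, so a ProcessPoolExecutor created inside a daemonic worker fails as soon as it starts its first subprocess. A minimal standalone sketch of the failure mode (not taskiq code):

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Process


def worker() -> None:
    # Submitting work forces the pool to spawn its first child process.
    # Under a daemonic parent this raises
    # "daemonic processes are not allowed to have children".
    with ProcessPoolExecutor(max_workers=2) as pool:
        print(pool.submit(sum, [1, 2, 3]).result())


if __name__ == "__main__":
    proc = Process(target=worker, daemon=True)  # daemon=False would succeed
    proc.start()
    proc.join()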
taskiq/cli/worker/run.py (10 changes: 8 additions & 2 deletions)
@@ -3,7 +3,7 @@
 import os
 import signal
 import sys
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
 from multiprocessing import set_start_method
 from sys import platform
 from typing import Any, Optional, Type
@@ -135,9 +135,15 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
     receiver_type = get_receiver_type(args)
     receiver_kwargs = dict(args.receiver_arg)
 
+    executor: Executor
+    if args.use_process_pool:
+        executor = ProcessPoolExecutor(max_workers=args.max_process_pool_processes)
+    else:
+        executor = ThreadPoolExecutor(max_workers=args.max_threadpool_threads)
+
     try:
         logger.debug("Initialize receiver.")
-        with ThreadPoolExecutor(args.max_threadpool_threads) as pool:
+        with executor as pool:
             receiver = receiver_type(
                 broker=broker,
                 executor=pool,
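Note the defaults here: --max-process-pool-processes is None, and ProcessPoolExecutor treats max_workers=None as the machine's CPU count (capped at 61 on Windows), while the thread pool keeps its explicit --max-threadpool-threads default of 10. A quick standalone illustration of the process-pool default:

import os
from concurrent.futures import ProcessPoolExecutor

if __name__ == "__main__":
    # max_workers=None sizes the pool to os.cpu_count() workers.
    with ProcessPoolExecutor(max_workers=None) as pool:
        squares = list(pool.map(pow, range(5), [2] * 5))
    print(os.cpu_count(), squares)  # e.g. 8 [0, 1, 4, 9, 16]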
taskiq/decor.py (32 changes: 32 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import sys
 from collections.abc import Coroutine
 from datetime import datetime
 from types import CoroutineType
@@ -56,6 +57,37 @@ def __init__(
         self.original_func = original_func
         self.labels = labels
 
+        # This is a hack to make ProcessPoolExecutor work
+        # with decorated functions.
+        #
+        # The problem is that when we decorate a function
+        # it becomes a new class. This class has the same
+        # name as the original function.
+        #
+        # When receiver sends original function to another
+        # process, it will have the same name as the decorated
+        # class. This will cause an error, because ProcessPoolExecutor
+        # uses `__name__` and `__qualname__` attributes to
+        # import functions from other processes and then it verifies
+        # that the function is the same as the original one.
+        #
+        # This hack renames the original function and injects
+        # it back to the module where it was defined.
+        # This way ProcessPoolExecutor will be able to import
+        # the function by its name and verify its correctness.
+        new_name = f"{original_func.__name__}__taskiq_original"
+        self.original_func.__name__ = new_name
+        if hasattr(self.original_func, "__qualname__"):
+            original_qualname = self.original_func.__qualname__.rsplit(".")
+            original_qualname[-1] = new_name
+            new_qualname = ".".join(original_qualname)
+            self.original_func.__qualname__ = new_qualname
+        setattr(
+            sys.modules[original_func.__module__],
+            new_name,
+            original_func,
+        )
+
     # Docs for this method are omitted in order to help
     # your IDE resolve correct docs for it.
     def __call__(  # noqa: D102
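The comment block describes real pickle behavior: a function pickled by reference is looked up as getattr(sys.modules[func.__module__], func.__qualname__), and pickling fails if the resolved object is not the function itself, which is exactly what happens once the decorator has replaced the module attribute with a task object. A minimal standalone sketch of the same trick (class and function names are illustrative, not taskiq's):

import pickle
import sys


class FakeTask:
    # Stand-in for the decorated task object that wraps the original function.
    def __init__(self, original_func):
        self.original_func = original_func


def task(func):
    return FakeTask(func)


@task
def add(a: int, b: int) -> int:
    return a + b


# pickle.dumps(add.original_func) would fail here: looking up "add" in this
# module finds the FakeTask instance, not the function itself.

# The hack: rename the function and re-inject it into its module.
func = add.original_func
func.__name__ = func.__qualname__ = "add__taskiq_original"
setattr(sys.modules[func.__module__], "add__taskiq_original", func)

assert pickle.loads(pickle.dumps(func))(2, 3) == 5  # now picklable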