Skip to content

Commit 181ff66

Browse files
authored
Added option to run in processpool. (#428)
* Added option to run in processpool. * Workers aren't daemons now. * Fixed process sending error. * Fixed docs and qualname generation.
1 parent c0374eb commit 181ff66

File tree

5 files changed

+70
-10
lines changed

5 files changed

+70
-10
lines changed

taskiq/api/receiver.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from concurrent.futures import ThreadPoolExecutor
2+
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
33
from logging import getLogger
44
from typing import Optional, Type
55

@@ -13,13 +13,14 @@
1313
async def run_receiver_task(
1414
broker: AsyncBroker,
1515
receiver_cls: Type[Receiver] = Receiver,
16-
sync_workers: int = 4,
16+
sync_workers: Optional[int] = None,
1717
validate_params: bool = True,
1818
max_async_tasks: int = 100,
1919
max_prefetch: int = 0,
2020
propagate_exceptions: bool = True,
2121
run_startup: bool = False,
2222
ack_time: Optional[AcknowledgeType] = None,
23+
use_process_pool: bool = False,
2324
) -> None:
2425
"""
2526
Function to run receiver programmatically.
@@ -39,13 +40,15 @@ async def run_receiver_task(
3940
4041
:param broker: current broker instance.
4142
:param receiver_cls: receiver class to use.
42-
:param sync_workers: number of threads of a threadpool that runs sync tasks.
43+
:param sync_workers: number of threads of a threadpool
44+
or processes in processpool that runs sync tasks.
4345
:param validate_params: whether to validate params or not.
4446
:param max_async_tasks: maximum number of simultaneous async tasks.
4547
:param max_prefetch: maximum number of tasks to prefetch.
4648
:param propagate_exceptions: whether to propagate exceptions in generators or not.
4749
:param run_startup: whether to run startup function or not.
4850
:param ack_time: acknowledge type to use.
51+
:param use_process_pool: whether to use process pool or threadpool.
4952
:raises asyncio.CancelledError: if the task was cancelled.
5053
"""
5154
finish_event = asyncio.Event()
@@ -62,7 +65,12 @@ def on_exit(_: Receiver) -> None:
6265
finish_event.set()
6366
raise asyncio.CancelledError
6467

65-
with ThreadPoolExecutor(max_workers=sync_workers) as executor:
68+
executor: Executor
69+
if use_process_pool:
70+
executor = ProcessPoolExecutor(max_workers=sync_workers)
71+
else:
72+
executor = ThreadPoolExecutor(max_workers=sync_workers)
73+
with executor as executor:
6674
broker.is_worker_process = True
6775
while True:
6876
try:

taskiq/cli/worker/args.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class WorkerArgs:
3232
log_level: LogLevel = LogLevel.INFO
3333
workers: int = 2
3434
max_threadpool_threads: int = 10
35+
max_process_pool_processes: Optional[int] = None
3536
no_parse: bool = False
3637
shutdown_timeout: float = 5
3738
reload: bool = False
@@ -46,6 +47,7 @@ class WorkerArgs:
4647
max_tasks_per_child: Optional[int] = None
4748
wait_tasks_timeout: Optional[float] = None
4849
hardkill_count: int = 3
50+
use_process_pool: bool = False
4951

5052
@classmethod
5153
def from_cli(
@@ -210,8 +212,7 @@ def from_cli(
210212
"--wait-tasks-timeout",
211213
type=float,
212214
default=None,
213-
help="Maximum time to wait for all current tasks "
214-
"to finish before exiting.",
215+
help="Maximum time to wait for all current tasks to finish before exiting.",
215216
)
216217
parser.add_argument(
217218
"--hardkill-count",
@@ -220,6 +221,19 @@ def from_cli(
220221
help="Number of termination signals to the main "
221222
"process before performing a hardkill.",
222223
)
224+
parser.add_argument(
225+
"--use-process-pool",
226+
action="store_true",
227+
dest="use_process_pool",
228+
help="Use process pool instead of thread pool for sync tasks.",
229+
)
230+
parser.add_argument(
231+
"--max-process-pool-processes",
232+
type=int,
233+
dest="max_process_pool_processes",
234+
default=None,
235+
help="Maximum number of processes in process pool.",
236+
)
223237

224238
namespace = parser.parse_args(args)
225239
# If there are any patterns specified, remove default.

taskiq/cli/worker/process_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def handle(
8383
target=worker_func,
8484
kwargs={"args": args},
8585
name=f"worker-{self.worker_num}",
86-
daemon=True,
86+
daemon=False,
8787
)
8888
new_process.start()
8989
logger.info(f"Process {new_process.name} restarted with pid {new_process.pid}")
@@ -193,7 +193,7 @@ def prepare_workers(self) -> None:
193193
target=self.worker_function,
194194
kwargs={"args": self.args},
195195
name=f"worker-{process}",
196-
daemon=True,
196+
daemon=False,
197197
)
198198
work_proc.start()
199199
logger.info(

taskiq/cli/worker/run.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import signal
66
import sys
7-
from concurrent.futures import ThreadPoolExecutor
7+
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
88
from multiprocessing import set_start_method
99
from sys import platform
1010
from typing import Any, Optional, Type
@@ -143,9 +143,15 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
143143
receiver_type = get_receiver_type(args)
144144
receiver_kwargs = dict(args.receiver_arg)
145145

146+
executor: Executor
147+
if args.use_process_pool:
148+
executor = ProcessPoolExecutor(max_workers=args.max_process_pool_processes)
149+
else:
150+
executor = ThreadPoolExecutor(max_workers=args.max_threadpool_threads)
151+
146152
try:
147153
logger.debug("Initialize receiver.")
148-
with ThreadPoolExecutor(args.max_threadpool_threads) as pool:
154+
with executor as pool:
149155
receiver = receiver_type(
150156
broker=broker,
151157
executor=pool,

taskiq/decor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from collections.abc import Coroutine
23
from datetime import datetime
34
from types import CoroutineType
@@ -56,6 +57,37 @@ def __init__(
5657
self.original_func = original_func
5758
self.labels = labels
5859

60+
# This is a hack to make ProcessPoolExecutor work
61+
# with decorated functions.
62+
#
63+
# The problem is that when we decorate a function
64+
# it becomes a new class. This class has the same
65+
# name as the original function.
66+
#
67+
# When receiver sends original function to another
68+
# process, it will have the same name as the decorated
69+
# class. This will cause an error, because ProcessPoolExecutor
70+
# uses `__name__` and `__qualname__` attributes to
71+
# import functions from other processes and then it verifies
72+
# that the function is the same as the original one.
73+
#
74+
# This hack renames the original function and injects
75+
# it back to the module where it was defined.
76+
# This way ProcessPoolExecutor will be able to import
77+
# the function by its name and verify its correctness.
78+
new_name = f"{original_func.__name__}__taskiq_original"
79+
self.original_func.__name__ = new_name
80+
if hasattr(self.original_func, "__qualname__"):
81+
original_qualname = self.original_func.__qualname__.rsplit(".")
82+
original_qualname[-1] = new_name
83+
new_qualname = ".".join(original_qualname)
84+
self.original_func.__qualname__ = new_qualname
85+
setattr(
86+
sys.modules[original_func.__module__],
87+
new_name,
88+
original_func,
89+
)
90+
5991
# Docs for this method are omitted in order to help
6092
# your IDE resolve correct docs for it.
6193
def __call__( # noqa: D102

0 commit comments

Comments
 (0)