Skip to content

Commit 091e703

Browse files
authored
Separation of the receiver from the CLI (#90)
1 parent a1daa65 commit 091e703

File tree

9 files changed

+76
-89
lines changed

9 files changed

+76
-89
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import asyncio
22
import inspect
33
from collections import OrderedDict
4+
from concurrent.futures import ThreadPoolExecutor
45
from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints
56

67
from taskiq_dependencies import DependencyGraph
78

89
from taskiq.abc.broker import AsyncBroker
910
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
10-
from taskiq.cli.worker.args import WorkerArgs
11-
from taskiq.cli.worker.receiver import Receiver
1211
from taskiq.events import TaskiqEvents
1312
from taskiq.exceptions import TaskiqError
1413
from taskiq.message import BrokerMessage
14+
from taskiq.receiver import Receiver
1515
from taskiq.utils import maybe_awaitable
1616

1717
_ReturnType = TypeVar("_ReturnType")
@@ -91,11 +91,11 @@ class InMemoryBroker(AsyncBroker):
9191
def __init__( # noqa: WPS211
9292
self,
9393
sync_tasks_pool_size: int = 4,
94-
logs_format: Optional[str] = None,
9594
max_stored_results: int = 100,
9695
cast_types: bool = True,
9796
result_backend: Optional[AsyncResultBackend[Any]] = None,
9897
task_id_generator: Optional[Callable[[], str]] = None,
98+
max_async_tasks: int = 30,
9999
) -> None:
100100
if result_backend is None:
101101
result_backend = InmemoryResultBackend(
@@ -105,15 +105,12 @@ def __init__( # noqa: WPS211
105105
result_backend=result_backend,
106106
task_id_generator=task_id_generator,
107107
)
108+
self.executor = ThreadPoolExecutor(sync_tasks_pool_size)
108109
self.receiver = Receiver(
109-
self,
110-
WorkerArgs(
111-
broker="",
112-
modules=[],
113-
max_threadpool_threads=sync_tasks_pool_size,
114-
no_parse=not cast_types,
115-
log_collector_format=logs_format or WorkerArgs.log_collector_format,
116-
),
110+
broker=self,
111+
executor=self.executor,
112+
validate_params=cast_types,
113+
max_async_tasks=max_async_tasks,
117114
)
118115
self._running_tasks: "Set[asyncio.Task[Any]]" = set()
119116

@@ -170,3 +167,4 @@ async def shutdown(self) -> None:
170167
for event in (TaskiqEvents.CLIENT_SHUTDOWN, TaskiqEvents.WORKER_SHUTDOWN):
171168
for handler in self.event_handlers.get(event, []):
172169
await maybe_awaitable(handler(self.state))
170+
self.executor.shutdown()

taskiq/cli/worker/args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class WorkerArgs:
2323
shutdown_timeout: float = 5
2424
reload: bool = False
2525
no_gitignore: bool = False
26-
max_async_tasks: int = 10
26+
max_async_tasks: int = 100
2727

2828
@classmethod
2929
def from_cli( # noqa: WPS213
@@ -128,7 +128,7 @@ def from_cli( # noqa: WPS213
128128
"--max-async-tasks",
129129
type=int,
130130
dest="max_async_tasks",
131-
default=10,
131+
default=100,
132132
help="Maximum simultaneous async tasks per worker process. ",
133133
)
134134

taskiq/cli/worker/async_task_runner.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

taskiq/cli/worker/run.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import asyncio
22
import logging
33
import signal
4+
from concurrent.futures import ThreadPoolExecutor
45
from typing import Any
56

67
from watchdog.observers import Observer
78

89
from taskiq.abc.broker import AsyncBroker
910
from taskiq.cli.utils import import_object, import_tasks
1011
from taskiq.cli.worker.args import WorkerArgs
11-
from taskiq.cli.worker.async_task_runner import async_listen_messages
1212
from taskiq.cli.worker.process_manager import ProcessManager
13+
from taskiq.receiver import Receiver
1314

1415
try:
1516
import uvloop # noqa: WPS433
@@ -102,7 +103,15 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
102103

103104
loop = asyncio.get_event_loop()
104105
try:
105-
loop.run_until_complete(async_listen_messages(broker, args))
106+
logger.debug("Initialize receiver.")
107+
with ThreadPoolExecutor(args.max_threadpool_threads) as pool:
108+
receiver = Receiver(
109+
broker=broker,
110+
executor=pool,
111+
validate_params=not args.no_parse,
112+
max_async_tasks=args.max_async_tasks,
113+
)
114+
loop.run_until_complete(receiver.listen())
106115
except KeyboardInterrupt:
107116
logger.warning("Worker process interrupted.")
108117
loop.run_until_complete(shutdown_broker(broker, args.shutdown_timeout))

taskiq/receiver/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Package for message receiver."""
2+
from taskiq.receiver.receiver import Receiver
3+
4+
__all__ = [
5+
"Receiver",
6+
]
File renamed without changes.

taskiq/cli/worker/receiver.py renamed to taskiq/receiver/receiver.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import asyncio
22
import inspect
3-
from concurrent.futures import ThreadPoolExecutor
3+
from concurrent.futures import Executor
44
from logging import getLogger
55
from time import time
6-
from typing import Any, Callable, Dict, get_type_hints
6+
from typing import Any, Callable, Dict, Optional, get_type_hints
77

88
from taskiq_dependencies import DependencyGraph
99

1010
from taskiq.abc.broker import AsyncBroker
1111
from taskiq.abc.middleware import TaskiqMiddleware
12-
from taskiq.cli.worker.args import WorkerArgs
13-
from taskiq.cli.worker.params_parser import parse_params
1412
from taskiq.context import Context
1513
from taskiq.message import BrokerMessage, TaskiqMessage
14+
from taskiq.receiver.params_parser import parse_params
1615
from taskiq.result import TaskiqResult
1716
from taskiq.state import TaskiqState
1817
from taskiq.utils import maybe_awaitable
@@ -37,20 +36,24 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
3736
class Receiver:
3837
"""Class that uses as a callback handler."""
3938

40-
def __init__(self, broker: AsyncBroker, cli_args: WorkerArgs) -> None:
39+
def __init__(
40+
self,
41+
broker: AsyncBroker,
42+
executor: Optional[Executor] = None,
43+
validate_params: bool = True,
44+
max_async_tasks: int = 20,
45+
) -> None:
4146
self.broker = broker
42-
self.cli_args = cli_args
47+
self.executor = executor
48+
self.validate_params = validate_params
4349
self.task_signatures: Dict[str, inspect.Signature] = {}
4450
self.task_hints: Dict[str, Dict[str, Any]] = {}
4551
self.dependency_graphs: Dict[str, DependencyGraph] = {}
4652
for task in self.broker.available_tasks.values():
4753
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
4854
self.task_hints[task.task_name] = get_type_hints(task.original_func)
4955
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
50-
self.executor = ThreadPoolExecutor(
51-
max_workers=cli_args.max_threadpool_threads,
52-
)
53-
self.sem = asyncio.Semaphore(cli_args.max_async_tasks)
56+
self.sem = asyncio.Semaphore(max_async_tasks)
5457

5558
async def callback( # noqa: C901, WPS213
5659
self,
@@ -152,10 +155,10 @@ async def run_task( # noqa: C901, WPS210
152155
loop = asyncio.get_running_loop()
153156
returned = None
154157
found_exception = None
155-
signature = self.task_signatures.get(message.task_name)
158+
signature = None
159+
if self.validate_params:
160+
signature = self.task_signatures.get(message.task_name)
156161
dependency_graph = self.dependency_graphs.get(message.task_name)
157-
if self.cli_args.no_parse:
158-
signature = None
159162
parse_params(signature, self.task_hints.get(message.task_name) or {}, message)
160163

161164
dep_ctx = None
@@ -221,3 +224,25 @@ async def run_task( # noqa: C901, WPS210
221224
)
222225

223226
return result
227+
228+
async def listen(self) -> None: # pragma: no cover
229+
"""
230+
This function iterates over tasks asynchronously.
231+
232+
It uses listen() method of an AsyncBroker
233+
to get new messages from queues.
234+
"""
235+
logger.debug("Runing startup event.")
236+
await self.broker.startup()
237+
logger.info("Listening started.")
238+
tasks = set()
239+
async for message in self.broker.listen():
240+
task = asyncio.create_task(self.callback(message=message, raise_err=False))
241+
tasks.add(task)
242+
243+
# We want the task to remove itself from the set when it's done.
244+
#
245+
# Because python's GC can silently cancel task
246+
# and it considered to be Hisenbug.
247+
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
248+
task.add_done_callback(tasks.discard)

tests/cli/worker/test_parameters_parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import pytest
66
from pydantic import BaseModel
77

8-
from taskiq.cli.worker.params_parser import parse_params
98
from taskiq.message import TaskiqMessage
9+
from taskiq.receiver.params_parser import parse_params
1010

1111

1212
class _TestPydanticClass(BaseModel):

tests/cli/worker/test_receiver.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from concurrent.futures import ThreadPoolExecutor
23
from typing import Any, Optional
34

45
import pytest
@@ -7,16 +8,15 @@
78
from taskiq.abc.broker import AsyncBroker
89
from taskiq.abc.middleware import TaskiqMiddleware
910
from taskiq.brokers.inmemory_broker import InMemoryBroker
10-
from taskiq.cli.worker.args import WorkerArgs
11-
from taskiq.cli.worker.receiver import Receiver
1211
from taskiq.message import BrokerMessage, TaskiqMessage
12+
from taskiq.receiver import Receiver
1313
from taskiq.result import TaskiqResult
1414

1515

1616
def get_receiver(
1717
broker: Optional[AsyncBroker] = None,
1818
no_parse: bool = False,
19-
cli_args: Optional[WorkerArgs] = None,
19+
max_async_tasks: int = 10,
2020
) -> Receiver:
2121
"""
2222
Returns receiver with custom broker and args.
@@ -28,15 +28,11 @@ def get_receiver(
2828
"""
2929
if broker is None:
3030
broker = InMemoryBroker()
31-
if cli_args is None:
32-
cli_args = WorkerArgs(
33-
broker="",
34-
modules=[],
35-
no_parse=no_parse,
36-
)
3731
return Receiver(
3832
broker,
39-
cli_args,
33+
executor=ThreadPoolExecutor(max_workers=10),
34+
validate_params=not no_parse,
35+
max_async_tasks=max_async_tasks,
4036
)
4137

4238

@@ -261,13 +257,7 @@ async def task_sem() -> int:
261257
await asyncio.sleep(1)
262258
return 1
263259

264-
cli_args = WorkerArgs(
265-
broker="",
266-
modules=[],
267-
no_parse=False,
268-
max_async_tasks=3,
269-
)
270-
receiver = get_receiver(broker, cli_args=cli_args)
260+
receiver = get_receiver(broker, max_async_tasks=3)
271261

272262
broker_message = broker.formatter.dumps(
273263
TaskiqMessage(
@@ -279,5 +269,6 @@ async def task_sem() -> int:
279269
),
280270
)
281271
tasks = [asyncio.create_task(receiver.callback(broker_message)) for _ in range(5)]
282-
await asyncio.sleep(0)
272+
await asyncio.sleep(0.3)
283273
assert sem_num == 3
274+
await asyncio.gather(*tasks)

0 commit comments

Comments
 (0)