Skip to content

Commit 1123d00

Browse files
committed
Merge branch 'release/0.4.3'
2 parents f54d1ef + aca7228 commit 1123d00

File tree

5 files changed

+71
-30
lines changed

5 files changed

+71
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "taskiq"
3-
version = "0.4.2"
3+
version = "0.4.3"
44
description = "Distributed task queue with full async support"
55
authors = ["Pavel Kirilin <[email protected]>"]
66
maintainers = ["Pavel Kirilin <[email protected]>"]

taskiq/brokers/inmemory_broker.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
from collections import OrderedDict
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints
5+
from typing import Any, AsyncGenerator, Set, TypeVar, get_type_hints
66

77
from taskiq_dependencies import DependencyGraph
88

@@ -88,22 +88,16 @@ class InMemoryBroker(AsyncBroker):
8888
It's useful for local development, if you don't want to setup real broker.
8989
"""
9090

91-
def __init__( # noqa: WPS211
91+
def __init__(
9292
self,
9393
sync_tasks_pool_size: int = 4,
9494
max_stored_results: int = 100,
9595
cast_types: bool = True,
96-
result_backend: Optional[AsyncResultBackend[Any]] = None,
97-
task_id_generator: Optional[Callable[[], str]] = None,
9896
max_async_tasks: int = 30,
9997
) -> None:
100-
if result_backend is None:
101-
result_backend = InmemoryResultBackend(
102-
max_stored_results=max_stored_results,
103-
)
104-
super().__init__(
105-
result_backend=result_backend,
106-
task_id_generator=task_id_generator,
98+
super().__init__()
99+
self.result_backend = InmemoryResultBackend(
100+
max_stored_results=max_stored_results,
107101
)
108102
self.executor = ThreadPoolExecutor(sync_tasks_pool_size)
109103
self.receiver = Receiver(

taskiq/cli/worker/args.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
2-
from dataclasses import dataclass
3-
from typing import List, Optional, Sequence
2+
from dataclasses import dataclass, field
3+
from typing import List, Optional, Sequence, Tuple
44

55
from taskiq.cli.common_args import LogLevel
66

77

8+
def receiver_arg_type(string: str) -> Tuple[str, str]:
9+
"""
10+
Parse cli --receiver_arg argument value.
11+
12+
:param string: cli argument value in format key=value.
13+
:raises ValueError: if value not in format.
14+
:return: (key, value) pair.
15+
"""
16+
args = string.split("=", 1)
17+
if len(args) != 2:
18+
raise ValueError(f"Invalid value: {string}")
19+
return args[0], args[1]
20+
21+
822
@dataclass
923
class WorkerArgs:
1024
"""Taskiq worker CLI arguments."""
@@ -24,6 +38,8 @@ class WorkerArgs:
2438
reload: bool = False
2539
no_gitignore: bool = False
2640
max_async_tasks: int = 100
41+
receiver: str = "taskiq.receiver:Receiver"
42+
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)
2743

2844
@classmethod
2945
def from_cli( # noqa: WPS213
@@ -45,6 +61,26 @@ def from_cli( # noqa: WPS213
4561
"'module.module:variable' format."
4662
),
4763
)
64+
parser.add_argument(
65+
"--receiver",
66+
default="taskiq.receiver:Receiver",
67+
help=(
68+
"Where to search for receiver. "
69+
"This string must be specified in "
70+
"'module.module:variable' format."
71+
),
72+
)
73+
parser.add_argument(
74+
"--receiver_arg",
75+
action="append",
76+
type=receiver_arg_type,
77+
default=[],
78+
help=(
79+
"List of args fot receiver. "
80+
"This string must be specified in "
81+
"`key=value` format."
82+
),
83+
)
4884
parser.add_argument(
4985
"--tasks-pattern",
5086
"-tp",

taskiq/cli/worker/run.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import signal
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Any
5+
from typing import Any, Type
66

77
from taskiq.abc.broker import AsyncBroker
88
from taskiq.cli.utils import import_object, import_tasks
@@ -51,7 +51,21 @@ async def shutdown_broker(broker: AsyncBroker, timeout: float) -> None:
5151
)
5252

5353

54-
def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
54+
def get_receiver_type(args: WorkerArgs) -> Type[Receiver]:
55+
"""
56+
Import Receiver from args.
57+
58+
:param args: CLI arguments.
59+
:raises ValueError: if receiver is not a Receiver type.
60+
:return: Receiver type.
61+
"""
62+
receiver_type = import_object(args.receiver)
63+
if not (isinstance(receiver_type, type) and issubclass(receiver_type, Receiver)):
64+
raise ValueError("Unknown receiver type. Please use Receiver class.")
65+
return receiver_type
66+
67+
68+
def start_listen(args: WorkerArgs) -> None: # noqa: WPS210, WPS213
5569
"""
5670
This function starts actual listening process.
5771
@@ -63,6 +77,7 @@ def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
6377
6478
:param args: CLI arguments.
6579
:raises ValueError: if broker is not an AsyncBroker instance.
80+
:raises ValueError: if receiver is not a Receiver type.
6681
"""
6782
if uvloop is not None:
6883
logger.debug("UVLOOP found. Installing policy.")
@@ -77,6 +92,9 @@ def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
7792
if not isinstance(broker, AsyncBroker):
7893
raise ValueError("Unknown broker type. Please use AsyncBroker instance.")
7994

95+
receiver_type = get_receiver_type(args)
96+
receiver_args = dict(args.receiver_arg)
97+
8098
# Here how we manage interruptions.
8199
# We have to remember shutting_down state,
82100
# because KeyboardInterrupt can be send multiple
@@ -105,14 +123,16 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
105123
signal.signal(signal.SIGTERM, interrupt_handler)
106124

107125
loop = asyncio.get_event_loop()
126+
108127
try:
109128
logger.debug("Initialize receiver.")
110129
with ThreadPoolExecutor(args.max_threadpool_threads) as pool:
111-
receiver = Receiver(
130+
receiver = receiver_type(
112131
broker=broker,
113132
executor=pool,
114133
validate_params=not args.no_parse,
115134
max_async_tasks=args.max_async_tasks,
135+
**receiver_args,
116136
)
117137
loop.run_until_complete(receiver.listen())
118138
except KeyboardInterrupt:

tests/cli/worker/test_receiver.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
22
from concurrent.futures import ThreadPoolExecutor
3-
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar
3+
from typing import Any, AsyncGenerator, List, Optional, TypeVar
44

55
import pytest
66
from taskiq_dependencies import Depends
77

88
from taskiq.abc.broker import AsyncBroker
99
from taskiq.abc.middleware import TaskiqMiddleware
10-
from taskiq.abc.result_backend import AsyncResultBackend
1110
from taskiq.brokers.inmemory_broker import InMemoryBroker
1211
from taskiq.message import TaskiqMessage
1312
from taskiq.receiver import Receiver
@@ -17,15 +16,8 @@
1716

1817

1918
class BrokerForTests(InMemoryBroker):
20-
def __init__(
21-
self,
22-
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
23-
task_id_generator: Optional[Callable[[], str]] = None,
24-
) -> None:
25-
super().__init__(
26-
result_backend=result_backend,
27-
task_id_generator=task_id_generator,
28-
)
19+
def __init__(self) -> None:
20+
super().__init__()
2921
self.to_send: "List[TaskiqMessage]" = []
3022

3123
async def listen(self) -> AsyncGenerator[bytes, None]:
@@ -142,8 +134,7 @@ def on_error(
142134
def test_func() -> None:
143135
raise ValueError()
144136

145-
broker = InMemoryBroker()
146-
broker.add_middlewares(_TestMiddleware())
137+
broker = InMemoryBroker().with_middlewares(_TestMiddleware())
147138
receiver = get_receiver(broker)
148139

149140
result = await receiver.run_task(

0 commit comments

Comments
 (0)