Skip to content

Commit 4b8f015

Browse files
authored
Added local task registry, improved shared_tasks API. (#203)
1 parent a7de65d commit 4b8f015

File tree

8 files changed

+76
-16
lines changed

8 files changed

+76
-16
lines changed

taskiq/abc/broker.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AsyncGenerator,
1212
Awaitable,
1313
Callable,
14+
ClassVar,
1415
DefaultDict,
1516
Dict,
1617
List,
@@ -68,7 +69,7 @@ class AsyncBroker(ABC):
6869
in async mode.
6970
"""
7071

71-
available_tasks: Dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {}
72+
global_task_registry: ClassVar[Dict[str, AsyncTaskiqDecoratedTask[Any, Any]]] = {}
7273

7374
def __init__(
7475
self,
@@ -98,6 +99,7 @@ def __init__(
9899
self.decorator_class = AsyncTaskiqDecoratedTask
99100
self.formatter: "TaskiqFormatter" = JSONFormatter()
100101
self.id_generator = task_id_generator
102+
self.local_task_registry: Dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {}
101103
# Every event has a list of handlers.
102104
# Every handler is a function which takes state as a first argument.
103105
# And handler can be either sync or async.
@@ -112,6 +114,41 @@ def __init__(
112114
# True only if broker runs in scheduler process.
113115
self.is_scheduler_process: bool = False
114116

117+
def find_task(self, task_name: str) -> Optional[AsyncTaskiqDecoratedTask[Any, Any]]:
118+
"""
119+
Returns task by name.
120+
121+
This method should be used to get task by name.
122+
Instead of accessing `available_tasks` or `local_available_tasks` directly.
123+
124+
It searches task by name in dict of tasks that
125+
were registered for this broker directly.
126+
If it fails, it checks global dict of all available tasks.
127+
128+
:param task_name: name of a task.
129+
:returns: found task or None.
130+
"""
131+
return self.local_task_registry.get(
132+
task_name,
133+
) or self.global_task_registry.get(
134+
task_name,
135+
)
136+
137+
def get_all_tasks(self) -> Dict[str, AsyncTaskiqDecoratedTask[Any, Any]]:
138+
"""
139+
Method to fetch all tasks available in broker.
140+
141+
This method returns all tasks, globally and locally
142+
available in broker. With local tasks having higher priority.
143+
144+
So, if you have two tasks with the same name,
145+
one registered in global registry and one registered
146+
in local registry, then local task will be returned.
147+
148+
:return: dict of all tasks. Keys are task names, values are tasks.
149+
"""
150+
return {**self.global_task_registry, **self.local_task_registry}
151+
115152
def add_dependency_context(self, new_ctx: Dict[Any, Any]) -> None:
116153
"""
117154
Add first-level dependencies.
@@ -291,7 +328,7 @@ def inner(
291328
),
292329
)
293330

294-
self.available_tasks[decorated_task.task_name] = decorated_task
331+
self._register_task(decorated_task.task_name, decorated_task)
295332

296333
return decorated_task
297334

@@ -416,3 +453,19 @@ def with_event_handlers(
416453
"""
417454
self.event_handlers[event].extend(handlers)
418455
return self
456+
457+
def _register_task(
458+
self,
459+
task_name: str,
460+
task: AsyncTaskiqDecoratedTask[Any, Any],
461+
) -> None:
462+
"""
463+
Mehtod is used to register tasks.
464+
465+
By default we register tasks in local task registry.
466+
But this behaviour can be changed in subclasses.
467+
468+
:param task_name: Name of a task.
469+
:param task: Decorated task.
470+
"""
471+
self.local_task_registry[task_name] = task

taskiq/brokers/inmemory_broker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def kick(self, message: BrokerMessage) -> None:
120120
121121
:raises TaskiqError: if someone wants to kick unknown task.
122122
"""
123-
target_task = self.available_tasks.get(message.task_name)
123+
target_task = self.find_task(message.task_name)
124124
if target_task is None:
125125
raise TaskiqError("Unknown task.")
126126

taskiq/brokers/shared_broker.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import AsyncGenerator, Optional, TypeVar
1+
from typing import Any, AsyncGenerator, Optional, TypeVar
22

33
from typing_extensions import ParamSpec
44

@@ -71,5 +71,13 @@ async def listen(self) -> AsyncGenerator[bytes, None]: # type: ignore
7171
"""
7272
raise TaskiqError("Shared broker cannot listen")
7373

74+
def _register_task(
75+
self,
76+
task_name: str,
77+
task: AsyncTaskiqDecoratedTask[Any, Any],
78+
) -> None:
79+
self.global_task_registry[task_name] = task
80+
7481

7582
async_shared_broker = AsyncSharedBroker()
83+
shared_task = async_shared_broker.task

taskiq/cli/worker/run.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ def start_listen(args: WorkerArgs, event: Event) -> None: # noqa: WPS210, WPS21
7373
This function starts actual listening process.
7474
7575
It imports broker and all tasks.
76-
Since tasks registers themselves in a global set,
77-
it's easy to just import module where you have decorated
78-
function and they will be available in broker's `available_tasks`
79-
field.
76+
Since tasks auto registeres themselves in a broker,
77+
we don't need to do anything else other than importing.
78+
8079
8180
:param args: CLI arguments.
8281
:param event: Event for notification.

taskiq/receiver/receiver.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__( # noqa: WPS211
6060
self.task_hints: Dict[str, Dict[str, Any]] = {}
6161
self.dependency_graphs: Dict[str, DependencyGraph] = {}
6262
self.propagate_exceptions = propagate_exceptions
63-
for task in self.broker.available_tasks.values():
63+
for task in self.broker.get_all_tasks().values():
6464
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
6565
self.task_hints[task.task_name] = get_type_hints(task.original_func)
6666
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
@@ -106,7 +106,8 @@ async def callback( # noqa: C901, WPS213, WPS217
106106
)
107107
return
108108
logger.debug(f"Received message: {taskiq_msg}")
109-
if taskiq_msg.task_name not in self.broker.available_tasks:
109+
task = self.broker.find_task(taskiq_msg.task_name)
110+
if task is None:
110111
logger.warning(
111112
'task "%s" is not found. Maybe you forgot to import it?',
112113
taskiq_msg.task_name,
@@ -135,7 +136,7 @@ async def callback( # noqa: C901, WPS213, WPS217
135136
await maybe_awaitable(message.ack())
136137

137138
result = await self.run_task(
138-
target=self.broker.available_tasks[taskiq_msg.task_name].original_func,
139+
target=task.original_func,
139140
message=taskiq_msg,
140141
)
141142

taskiq/schedule_sources/label_based.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def get_schedules(self) -> List["ScheduledTask"]:
2727
:return: list of schedules.
2828
"""
2929
schedules = []
30-
for task_name, task in self.broker.available_tasks.items():
30+
for task_name, task in self.broker.get_all_tasks().items():
3131
if task.broker != self.broker:
3232
continue
3333
for schedule in task.labels.get("schedule", []):
@@ -61,7 +61,7 @@ def post_send(self, scheduled_task: ScheduledTask) -> None:
6161
if scheduled_task.cron or not scheduled_task.time:
6262
return # it's scheduled task with cron label, do not remove this trigger.
6363

64-
for task_name, task in self.broker.available_tasks.items():
64+
for task_name, task in self.broker.get_all_tasks().items():
6565
if task.broker != self.broker or scheduled_task.task_name != task_name:
6666
continue
6767

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def reset_broker() -> Generator[None, None, None]:
2727
broker variables to default state.
2828
"""
2929
yield
30-
AsyncBroker.available_tasks = {}
30+
AsyncBroker.global_task_registry = {}
3131
AsyncBroker.is_worker_process = False
3232
AsyncBroker.is_scheduler_process = False
3333

tox.ini

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,4 @@ allowlist_externals = poetry
1212
commands_pre =
1313
poetry install
1414
commands =
15-
pre-commit run --all-files
16-
poetry run pytest -vv
15+
poetry run pytest -vv -n auto

0 commit comments

Comments
 (0)