Skip to content

Commit 28b873f

Browse files
authored
Added dynamic dependency resolution for unknown tasks (#208)
1 parent feb2481 commit 28b873f

File tree

3 files changed

+62
-21
lines changed

3 files changed

+62
-21
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import asyncio
2-
import inspect
32
from collections import OrderedDict
43
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Any, AsyncGenerator, Set, TypeVar, get_type_hints
6-
7-
from taskiq_dependencies import DependencyGraph
4+
from typing import Any, AsyncGenerator, Set, TypeVar
85

96
from taskiq.abc.broker import AsyncBroker
107
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
@@ -124,19 +121,6 @@ async def kick(self, message: BrokerMessage) -> None:
124121
if target_task is None:
125122
raise TaskiqError("Unknown task.")
126123

127-
if not self.receiver.dependency_graphs.get(target_task.task_name):
128-
self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph(
129-
target_task.original_func,
130-
)
131-
if not self.receiver.task_signatures.get(target_task.task_name):
132-
self.receiver.task_signatures[target_task.task_name] = inspect.signature(
133-
target_task.original_func,
134-
)
135-
if not self.receiver.task_hints.get(target_task.task_name):
136-
self.receiver.task_hints[target_task.task_name] = get_type_hints(
137-
target_task.original_func,
138-
)
139-
140124
task = asyncio.create_task(self.receiver.callback(message=message.message))
141125
self._running_tasks.add(task)
142126
task.add_done_callback(self._running_tasks.discard)

taskiq/receiver/receiver.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ def __init__( # noqa: WPS211
6464
self.dependency_graphs: Dict[str, DependencyGraph] = {}
6565
self.propagate_exceptions = propagate_exceptions
6666
self.on_exit = on_exit
67+
self.known_tasks: Set[str] = set()
6768
for task in self.broker.get_all_tasks().values():
68-
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
69-
self.task_hints[task.task_name] = get_type_hints(task.original_func)
70-
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
69+
self._prepare_task(task.task_name, task.original_func)
7170
self.sem: "Optional[asyncio.Semaphore]" = None
7271
if max_async_tasks is not None and max_async_tasks > 0:
7372
self.sem = asyncio.Semaphore(max_async_tasks)
@@ -163,7 +162,7 @@ async def callback( # noqa: C901, WPS213, WPS217
163162
if raise_err:
164163
raise exc
165164

166-
async def run_task( # noqa: C901, WPS210
165+
async def run_task( # noqa: C901, WPS210, WPS213
167166
self,
168167
target: Callable[..., Any],
169168
message: TaskiqMessage,
@@ -190,6 +189,8 @@ async def run_task( # noqa: C901, WPS210
190189
returned = None
191190
found_exception: "Optional[BaseException]" = None
192191
signature = None
192+
if message.task_name not in self.known_tasks:
193+
self._prepare_task(message.task_name, target)
193194
if self.validate_params:
194195
signature = self.task_signatures.get(message.task_name)
195196
dependency_graph = self.dependency_graphs.get(message.task_name)
@@ -382,3 +383,23 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
382383
# and this behaviour considered to be a Hisenbug.
383384
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
384385
task.add_done_callback(task_cb)
386+
387+
def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None:
388+
"""
389+
Prepare task for execution.
390+
391+
This function gets function's signature,
392+
type hints and builds dependency graph.
393+
394+
It's useful for dynamic dependency resolution,
395+
because sometimes the receiver can get
396+
funcion that is defined in runtime. We need
397+
to be aware of that.
398+
399+
:param name: task name.
400+
:param handler: task handler.
401+
"""
402+
self.known_tasks.add(name)
403+
self.task_signatures[name] = inspect.signature(handler)
404+
self.task_hints[name] = get_type_hints(handler)
405+
self.dependency_graphs[name] = DependencyGraph(handler)

tests/receiver/test_receiver.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import random
23
import time
34
from concurrent.futures import ThreadPoolExecutor
45
from typing import Any, AsyncGenerator, List, Optional, TypeVar
@@ -227,6 +228,41 @@ async def my_task() -> int:
227228
assert called_times == 1
228229

229230

231+
@pytest.mark.anyio
232+
async def test_callback_no_dep_info() -> None:
233+
"""Test that callback function works well."""
234+
broker = InMemoryBroker()
235+
expected = random.randint(1, 100)
236+
ret_val = None
237+
238+
def dependency() -> int:
239+
return expected
240+
241+
@broker.task
242+
async def my_task(dep: int = Depends(dependency)) -> None:
243+
nonlocal ret_val
244+
ret_val = dep
245+
246+
receiver = get_receiver(broker)
247+
receiver.known_tasks.remove(my_task.task_name)
248+
receiver.dependency_graphs.pop(my_task.task_name, None)
249+
receiver.task_signatures.pop(my_task.task_name, None)
250+
receiver.task_hints.pop(my_task.task_name, None)
251+
252+
broker_message = broker.formatter.dumps(
253+
TaskiqMessage(
254+
task_id="task_id",
255+
task_name=my_task.task_name,
256+
labels={},
257+
args=[],
258+
kwargs={},
259+
),
260+
)
261+
262+
await receiver.callback(broker_message.message)
263+
assert ret_val == expected
264+
265+
230266
@pytest.mark.anyio
231267
async def test_callback_success_ackable() -> None:
232268
"""Test that acking works."""

0 commit comments

Comments
 (0)