Skip to content

Commit fc93789

Browse files
authored
Added ability to add custom first-level dependencies (#70)
1 parent 2512c58 commit fc93789

File tree

5 files changed

+67
-2
lines changed

5 files changed

+67
-2
lines changed

docs/guide/state-and-deps.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,13 @@ By default taskiq has only two deendencies:
171171

172172
- Context from `taskiq.context.Context`
173173
- TaskiqState from `taskiq.state.TaskiqState`
174+
175+
176+
### Adding first-level dependencies
177+
178+
You can expand default list of available dependencies for you application.
179+
Taskiq have an ability to add new first-level dependencies using brokers.
180+
181+
The AsyncBroker interface has a function called `add_dependency_context` and you can add
182+
more default dependencies to the taskiq. This may be useful for libraries if you want to
183+
add new dependencies to users.

taskiq/abc/broker.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,23 @@ def __init__(
8686
# Every event has a list of handlers.
8787
# Every handler is a function which takes state as a first argument.
8888
# And handler can be either sync or async.
89-
self.event_handlers: DefaultDict[ # noqa: WPS234
89+
self.event_handlers: DefaultDict[
9090
TaskiqEvents,
9191
List[Callable[[TaskiqState], Optional[Awaitable[None]]]],
9292
] = defaultdict(list)
9393
self.state = TaskiqState()
94+
self.custom_dependency_context: Dict[Any, Any] = {}
95+
96+
def add_dependency_context(self, new_ctx: Dict[Any, Any]) -> None:
97+
"""
98+
Add first-level dependencies.
99+
100+
Provided dict will be used to inject new dependencies
101+
in all dependency graph contexts.
102+
103+
:param new_ctx: Additional context values for dependnecy injection.
104+
"""
105+
self.custom_dependency_context.update(new_ctx)
94106

95107
def add_middlewares(self, *middlewares: "TaskiqMiddleware") -> None:
96108
"""

taskiq/cli/worker/receiver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,14 @@ async def run_task( # noqa: C901, WPS210
159159
dep_ctx = None
160160
if dependency_graph:
161161
# Create a context for dependency resolving.
162-
dep_ctx = dependency_graph.async_ctx(
162+
broker_ctx = self.broker.custom_dependency_context
163+
broker_ctx.update(
163164
{
164165
Context: Context(message, self.broker),
165166
TaskiqState: self.broker.state,
166167
},
167168
)
169+
dep_ctx = dependency_graph.async_ctx(broker_ctx)
168170
# Resolve all function's dependencies.
169171
dep_kwargs = await dep_ctx.resolve_kwargs()
170172
for key, val in dep_kwargs.items():

tests/cli/worker/test_custom_contexts.py

Whitespace-only changes.

tests/cli/worker/test_receiver.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Optional
22

33
import pytest
4+
from taskiq_dependencies import Depends
45

56
from taskiq.abc.broker import AsyncBroker
67
from taskiq.abc.middleware import TaskiqMiddleware
@@ -200,3 +201,43 @@ async def test_callback_unknown_task() -> None:
200201
)
201202

202203
await receiver.callback(broker_message)
204+
205+
206+
@pytest.mark.anyio
207+
async def test_custom_ctx() -> None:
208+
"""Tests that run_task can run sync tasks."""
209+
210+
class MyTestClass:
211+
"""Class to test injection."""
212+
213+
def __init__(self, val: int) -> None:
214+
self.val = val
215+
216+
broker = InMemoryBroker()
217+
218+
# We register a task into broker,
219+
# to build dependency graph on startup.
220+
@broker.task
221+
def test_func(tes_val: MyTestClass = Depends()) -> int:
222+
return tes_val.val
223+
224+
# We add custom first-level dependency.
225+
broker.add_dependency_context({MyTestClass: MyTestClass(11)})
226+
# Create a receiver.
227+
receiver = get_receiver(broker)
228+
229+
result = await receiver.run_task(
230+
test_func,
231+
TaskiqMessage(
232+
task_id="",
233+
task_name=test_func.task_name,
234+
labels={},
235+
args=[],
236+
kwargs={},
237+
),
238+
)
239+
240+
# Check that the value is equal
241+
# to the one we supplied.
242+
assert result.return_value == 11
243+
assert not result.is_err

0 commit comments

Comments
 (0)