Skip to content

Commit 63ab802

Browse files
authored
Added context injector. (#38)
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 445fd81 commit 63ab802

File tree

4 files changed

+127
-61
lines changed

4 files changed

+127
-61
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,10 @@ async def kick(self, message: BrokerMessage) -> None:
124124
target_task = self.available_tasks.get(message.task_name)
125125
if target_task is None:
126126
raise TaskiqError("Unknown task.")
127-
if self.receiver.task_signatures:
128-
if not self.receiver.task_signatures.get(target_task.task_name):
129-
self.receiver.task_signatures[
130-
target_task.task_name
131-
] = inspect.signature(
132-
target_task.original_func,
133-
)
127+
if not self.receiver.task_signatures.get(target_task.task_name):
128+
self.receiver.task_signatures[target_task.task_name] = inspect.signature(
129+
target_task.original_func,
130+
)
134131

135132
await self.receiver.callback(message=message)
136133

taskiq/cli/receiver.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,48 @@
44
from concurrent.futures import ThreadPoolExecutor
55
from logging import getLogger
66
from time import time
7-
from typing import Any, Callable, Dict
7+
from typing import Any, Callable, Dict, Optional
88

99
from taskiq.abc.broker import AsyncBroker
1010
from taskiq.abc.middleware import TaskiqMiddleware
1111
from taskiq.cli.args import TaskiqArgs
1212
from taskiq.cli.log_collector import log_collector
1313
from taskiq.cli.params_parser import parse_params
14-
from taskiq.context import Context, context_updater
14+
from taskiq.context import Context
1515
from taskiq.message import BrokerMessage, TaskiqMessage
1616
from taskiq.result import TaskiqResult
1717
from taskiq.utils import maybe_awaitable
1818

1919
logger = getLogger(__name__)
2020

2121

22+
def inject_context(
23+
signature: Optional[inspect.Signature],
24+
message: TaskiqMessage,
25+
broker: AsyncBroker,
26+
) -> None:
27+
"""
28+
Inject context parameter in message's kwargs.
29+
30+
This function parses signature to get
31+
the context parameter definition.
32+
33+
If at least one parameter has the Context
34+
type, it will add current context as kwarg.
35+
36+
:param signature: function's signature.
37+
:param message: current taskiq message.
38+
:param broker: current broker.
39+
"""
40+
if signature is None:
41+
return
42+
for param_name, param in signature.parameters.items():
43+
if param.annotation is param.empty:
44+
continue
45+
if param.annotation is Context:
46+
message.kwargs[param_name] = Context(message.copy(), broker)
47+
48+
2249
def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
2350
"""
2451
Runs function synchronously.
@@ -40,11 +67,8 @@ def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None:
4067
self.broker = broker
4168
self.cli_args = cli_args
4269
self.task_signatures: Dict[str, inspect.Signature] = {}
43-
if not cli_args.no_parse:
44-
for task in self.broker.available_tasks.values():
45-
self.task_signatures[task.task_name] = inspect.signature(
46-
task.original_func,
47-
)
70+
for task in self.broker.available_tasks.values():
71+
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
4872
self.executor = ThreadPoolExecutor(
4973
max_workers=cli_args.max_threadpool_threads,
5074
)
@@ -100,11 +124,10 @@ async def callback( # noqa: C901
100124
taskiq_msg.task_name,
101125
taskiq_msg.task_id,
102126
)
103-
with context_updater(Context(taskiq_msg, self.broker)):
104-
result = await self.run_task(
105-
target=self.broker.available_tasks[message.task_name].original_func,
106-
message=taskiq_msg,
107-
)
127+
result = await self.run_task(
128+
target=self.broker.available_tasks[message.task_name].original_func,
129+
message=taskiq_msg,
130+
)
108131
for middleware in self.broker.middlewares:
109132
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
110133
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
@@ -147,7 +170,15 @@ async def run_task( # noqa: C901, WPS210
147170
logs = io.StringIO()
148171
returned = None
149172
found_exception = None
150-
parse_params(self.task_signatures.get(message.task_name), message)
173+
signature = self.task_signatures.get(message.task_name)
174+
if self.cli_args.no_parse:
175+
signature = None
176+
parse_params(signature, message)
177+
inject_context(
178+
self.task_signatures.get(message.task_name),
179+
message,
180+
self.broker,
181+
)
151182
# Captures function's logs.
152183
with log_collector(logs, self.cli_args.log_collector_format):
153184
# Start a timer.

taskiq/cli/tests/test_context.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import inspect
2+
3+
from taskiq.cli.receiver import inject_context
4+
from taskiq.context import Context
5+
from taskiq.message import TaskiqMessage
6+
7+
8+
def test_inject_context_success() -> None:
9+
"""Test that context variable is injected as expected."""
10+
11+
def func(param1: int, ctx: Context) -> int:
12+
return param1
13+
14+
message = TaskiqMessage(
15+
task_id="",
16+
task_name="",
17+
labels={},
18+
args=[1],
19+
kwargs={},
20+
)
21+
22+
inject_context(
23+
inspect.signature(func),
24+
message=message,
25+
broker=None, # type: ignore
26+
)
27+
28+
assert message.kwargs.get("ctx")
29+
assert isinstance(message.kwargs["ctx"], Context)
30+
31+
32+
def test_inject_context_no_typehint() -> None:
33+
"""Test that context won't be injected in untyped parameter."""
34+
35+
def func(param1: int, ctx) -> int: # type: ignore
36+
return param1
37+
38+
message = TaskiqMessage(
39+
task_id="",
40+
task_name="",
41+
labels={},
42+
args=[1],
43+
kwargs={},
44+
)
45+
46+
inject_context(
47+
inspect.signature(func),
48+
message=message,
49+
broker=None, # type: ignore
50+
)
51+
52+
assert message.kwargs.get("ctx") is None
53+
54+
55+
def test_inject_context_no_ctx_parameter() -> None:
56+
"""
57+
Tests that injector won't raise an error.
58+
59+
If the Context-typed parameter doesn't exist.
60+
"""
61+
62+
def func(param1: int) -> int:
63+
return param1
64+
65+
message = TaskiqMessage(
66+
task_id="",
67+
task_name="",
68+
labels={},
69+
args=[1],
70+
kwargs={},
71+
)
72+
73+
inject_context(
74+
inspect.signature(func),
75+
message=message,
76+
broker=None, # type: ignore
77+
)
78+
79+
assert not message.kwargs

taskiq/context.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from contextlib import contextmanager
2-
from typing import Generator
3-
41
from taskiq.abc.broker import AsyncBroker
52
from taskiq.message import TaskiqMessage
63

@@ -14,41 +11,3 @@ def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
1411

1512

1613
default_context = Context(None, None) # type: ignore
17-
current_context = None
18-
19-
20-
@contextmanager
21-
def context_updater(new_context: Context) -> Generator[None, None, None]:
22-
"""
23-
Update context for some time.
24-
25-
:param new_context: new context to set.
26-
:yield: nothing.
27-
"""
28-
global current_context # noqa: WPS420
29-
current_context = new_context # noqa: WPS442
30-
31-
yield
32-
33-
current_context = None # noqa: WPS442
34-
35-
36-
def get_context() -> Context:
37-
"""
38-
Get current context.
39-
40-
This function always return contexts,
41-
but if you call this function inside tests,
42-
or somewhere you have to be careful,
43-
since if current_context is None it will
44-
return default_context.
45-
46-
To override context please use context_updater
47-
context manager.
48-
49-
:return: context.
50-
"""
51-
global current_context # noqa: WPS420
52-
if current_context is None:
53-
return default_context
54-
return current_context

0 commit comments

Comments
 (0)