Skip to content

Commit 0e8b7c6

Browse files
authored
Fixed parameter parsing for string annotations. (#39)
* Fixed parameter parsing for string annotations. Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 63ab802 commit 0e8b7c6

File tree

5 files changed

+88
-28
lines changed

5 files changed

+88
-28
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
from collections import OrderedDict
3-
from typing import Any, Callable, Coroutine, Optional, TypeVar
3+
from typing import Any, Callable, Coroutine, Optional, TypeVar, get_type_hints
44

55
from taskiq.abc.broker import AsyncBroker
66
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
@@ -128,6 +128,10 @@ async def kick(self, message: BrokerMessage) -> None:
128128
self.receiver.task_signatures[target_task.task_name] = inspect.signature(
129129
target_task.original_func,
130130
)
131+
if not self.receiver.task_hints.get(target_task.task_name):
132+
self.receiver.task_hints[target_task.task_name] = get_type_hints(
133+
target_task.original_func,
134+
)
131135

132136
await self.receiver.callback(message=message)
133137

taskiq/cli/params_parser.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
from logging import getLogger
3-
from typing import Optional
3+
from typing import Any, Dict, Optional
44

55
from pydantic import parse_obj_as
66

@@ -11,6 +11,7 @@
1111

1212
def parse_params( # noqa: C901
1313
signature: Optional[inspect.Signature],
14+
type_hints: Dict[str, Any],
1415
message: TaskiqMessage,
1516
) -> None:
1617
"""
@@ -42,25 +43,30 @@ def parse_params( # noqa: C901
4243
or you can make some of parameters untyped,
4344
or use Any.
4445
46+
Why do we need type_hints separate with
47+
Signature. The reason is simple.
48+
If some variable doesn't have a type hint
49+
it won't be added in the dict of type hints.
50+
4551
:param signature: original function's signature.
52+
:param type_hints: function's type hints.
4653
:param message: incoming message.
4754
"""
4855
if signature is None:
4956
return
5057
argnum = -1
5158
# Iterate over function's params.
52-
for param_name, params_type in signature.parameters.items():
59+
for param_name in signature.parameters.keys():
5360
# If parameter doesn't have an annotation.
54-
if params_type.annotation is params_type.empty:
61+
annot = type_hints.get(param_name)
62+
if annot is None:
5563
continue
5664
# Increment argument numbers. This is
5765
# for positional arguments.
5866
argnum += 1
59-
# Shortland for params_type.annotation
60-
annot = params_type.annotation
6167
# Value from incoming message.
6268
value = None
63-
logger.debug("Trying to parse %s as %s", param_name, params_type.annotation)
69+
logger.debug("Trying to parse %s as %s", param_name, annot)
6470
# Check if we have positional arguments in passed message.
6571
if argnum < len(message.args):
6672
# Get positional argument.

taskiq/cli/receiver.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from concurrent.futures import ThreadPoolExecutor
55
from logging import getLogger
66
from time import time
7-
from typing import Any, Callable, Dict, Optional
7+
from typing import Any, Callable, Dict, get_type_hints
88

99
from taskiq.abc.broker import AsyncBroker
1010
from taskiq.abc.middleware import TaskiqMiddleware
@@ -20,7 +20,7 @@
2020

2121

2222
def inject_context(
23-
signature: Optional[inspect.Signature],
23+
type_hints: Dict[str, Any],
2424
message: TaskiqMessage,
2525
broker: AsyncBroker,
2626
) -> None:
@@ -33,16 +33,14 @@ def inject_context(
3333
If at least one parameter has the Context
3434
type, it will add current context as kwarg.
3535
36-
:param signature: function's signature.
36+
:param type_hints: function's type hints.
3737
:param message: current taskiq message.
3838
:param broker: current broker.
3939
"""
40-
if signature is None:
40+
if not type_hints:
4141
return
42-
for param_name, param in signature.parameters.items():
43-
if param.annotation is param.empty:
44-
continue
45-
if param.annotation is Context:
42+
for param_name, param_type in type_hints.items():
43+
if param_type is Context:
4644
message.kwargs[param_name] = Context(message.copy(), broker)
4745

4846

@@ -67,8 +65,10 @@ def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None:
6765
self.broker = broker
6866
self.cli_args = cli_args
6967
self.task_signatures: Dict[str, inspect.Signature] = {}
68+
self.task_hints: Dict[str, Dict[str, Any]] = {}
7069
for task in self.broker.available_tasks.values():
7170
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
71+
self.task_hints[task.task_name] = get_type_hints(task.original_func)
7272
self.executor = ThreadPoolExecutor(
7373
max_workers=cli_args.max_threadpool_threads,
7474
)
@@ -173,9 +173,9 @@ async def run_task( # noqa: C901, WPS210
173173
signature = self.task_signatures.get(message.task_name)
174174
if self.cli_args.no_parse:
175175
signature = None
176-
parse_params(signature, message)
176+
parse_params(signature, self.task_hints.get(message.task_name) or {}, message)
177177
inject_context(
178-
self.task_signatures.get(message.task_name),
178+
self.task_hints.get(message.task_name) or {},
179179
message,
180180
self.broker,
181181
)

taskiq/cli/tests/test_context.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import inspect
1+
from typing import get_type_hints
22

33
from taskiq.cli.receiver import inject_context
44
from taskiq.context import Context
@@ -20,7 +20,36 @@ def func(param1: int, ctx: Context) -> int:
2020
)
2121

2222
inject_context(
23-
inspect.signature(func),
23+
get_type_hints(func),
24+
message=message,
25+
broker=None, # type: ignore
26+
)
27+
28+
assert message.kwargs.get("ctx")
29+
assert isinstance(message.kwargs["ctx"], Context)
30+
31+
32+
def test_inject_context_success_string_annotation() -> None:
33+
"""
34+
Test that context variable is injected as expected.
35+
36+
This test checks that if Context was provided as
37+
string, then everything is work as expected.
38+
"""
39+
40+
def func(param1: int, ctx: "Context") -> int:
41+
return param1
42+
43+
message = TaskiqMessage(
44+
task_id="",
45+
task_name="",
46+
labels={},
47+
args=[1],
48+
kwargs={},
49+
)
50+
51+
inject_context(
52+
get_type_hints(func),
2453
message=message,
2554
broker=None, # type: ignore
2655
)
@@ -44,7 +73,7 @@ def func(param1: int, ctx) -> int: # type: ignore
4473
)
4574

4675
inject_context(
47-
inspect.signature(func),
76+
get_type_hints(func),
4877
message=message,
4978
broker=None, # type: ignore
5079
)
@@ -71,7 +100,7 @@ def func(param1: int) -> int:
71100
)
72101

73102
inject_context(
74-
inspect.signature(func),
103+
get_type_hints(func),
75104
message=message,
76105
broker=None, # type: ignore
77106
)

taskiq/cli/tests/test_parameters_parsing.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
from dataclasses import dataclass
3-
from typing import Any, Type
3+
from typing import Any, Type, get_type_hints
44

55
import pytest
66
from pydantic import BaseModel
@@ -30,6 +30,7 @@ def test_parse_params_no_signature() -> None:
3030
modify_msg = src_msg.copy(deep=True)
3131
parse_params(
3232
signature=None,
33+
type_hints={},
3334
message=modify_msg,
3435
)
3536

@@ -51,7 +52,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
5152
kwargs={},
5253
)
5354

54-
parse_params(inspect.signature(test_func), msg_with_args)
55+
parse_params(
56+
inspect.signature(test_func),
57+
get_type_hints(test_func),
58+
msg_with_args,
59+
)
5560

5661
assert isinstance(msg_with_args.args[0], test_class)
5762
assert msg_with_args.args[0].field == "test_val"
@@ -64,7 +69,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
6469
kwargs={"param": {"field": "test_val"}},
6570
)
6671

67-
parse_params(inspect.signature(test_func), msg_with_kwargs)
72+
parse_params(
73+
inspect.signature(test_func),
74+
get_type_hints(test_func),
75+
msg_with_kwargs,
76+
)
6877

6978
assert isinstance(msg_with_kwargs.kwargs["param"], test_class)
7079
assert msg_with_kwargs.kwargs["param"].field == "test_val"
@@ -85,7 +94,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
8594
kwargs={},
8695
)
8796

88-
parse_params(inspect.signature(test_func), msg_with_args)
97+
parse_params(
98+
inspect.signature(test_func),
99+
get_type_hints(test_func),
100+
msg_with_args,
101+
)
89102

90103
assert isinstance(msg_with_args.args[0], dict)
91104

@@ -97,7 +110,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
97110
kwargs={"param": {"unknown": "unknown"}},
98111
)
99112

100-
parse_params(inspect.signature(test_func), msg_with_kwargs)
113+
parse_params(
114+
inspect.signature(test_func),
115+
get_type_hints(test_func),
116+
msg_with_kwargs,
117+
)
101118

102119
assert isinstance(msg_with_kwargs.kwargs["param"], dict)
103120

@@ -117,7 +134,7 @@ def test_func(param: test_class) -> test_class: # type: ignore
117134
kwargs={},
118135
)
119136

120-
parse_params(inspect.signature(test_func), msg_with_args)
137+
parse_params(inspect.signature(test_func), get_type_hints(test_func), msg_with_args)
121138

122139
assert msg_with_args.args[0] is None
123140

@@ -129,6 +146,10 @@ def test_func(param: test_class) -> test_class: # type: ignore
129146
kwargs={"param": None},
130147
)
131148

132-
parse_params(inspect.signature(test_func), msg_with_kwargs)
149+
parse_params(
150+
inspect.signature(test_func),
151+
get_type_hints(test_func),
152+
msg_with_kwargs,
153+
)
133154

134155
assert msg_with_kwargs.kwargs["param"] is None

0 commit comments

Comments
 (0)