Skip to content

Commit 9235a09

Browse files
committed
Added result parsing on return.
1 parent 826c31e commit 9235a09

File tree

5 files changed

+33
-2
lines changed

5 files changed

+33
-2
lines changed

taskiq/abc/broker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
Optional,
1919
TypeVar,
2020
Union,
21+
get_type_hints,
2122
overload,
2223
)
2324
from uuid import uuid4
2425

26+
from pydantic import TypeAdapter
2527
from typing_extensions import ParamSpec, Self, TypeAlias
2628

2729
from taskiq.abc.middleware import TaskiqMiddleware
@@ -326,12 +328,18 @@ def inner(
326328
inner_task_name = f"{fmodule}:{fname}"
327329
wrapper = wraps(func)
328330

331+
sign = get_type_hints(func)
332+
return_type = None
333+
if "return" in sign:
334+
return_type = TypeAdapter(sign["return"])
335+
329336
decorated_task = wrapper(
330337
self.decorator_class(
331338
broker=self,
332339
original_func=func,
333340
labels=inner_labels,
334341
task_name=inner_task_name,
342+
return_type=return_type,
335343
),
336344
)
337345

taskiq/brokers/shared_broker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]:
3030
task_name=self.task_name,
3131
broker=broker,
3232
labels=self.labels,
33+
return_type=self.return_type,
3334
)
3435

3536

taskiq/decor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
Callable,
88
Dict,
99
Generic,
10+
Optional,
1011
TypeVar,
1112
Union,
1213
overload,
1314
)
1415

16+
from pydantic import TypeAdapter
1517
from typing_extensions import ParamSpec
1618

1719
from taskiq.kicker import AsyncKicker
@@ -50,11 +52,13 @@ def __init__(
5052
task_name: str,
5153
original_func: Callable[_FuncParams, _ReturnType],
5254
labels: Dict[str, Any],
55+
return_type: Optional[TypeAdapter[_ReturnType]] = None,
5356
) -> None:
5457
self.broker = broker
5558
self.task_name = task_name
5659
self.original_func = original_func
5760
self.labels = labels
61+
self.return_type = return_type
5862

5963
# Docs for this method are omitted in order to help
6064
# your IDE resolve correct docs for it.
@@ -172,6 +176,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]:
172176
task_name=self.task_name,
173177
broker=self.broker,
174178
labels=self.labels,
179+
return_type=self.return_type,
175180
)
176181

177182
def __repr__(self) -> str:

taskiq/kicker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
overload,
1515
)
1616

17-
from pydantic import BaseModel
17+
from pydantic import BaseModel, TypeAdapter
1818
from typing_extensions import ParamSpec
1919

2020
from taskiq.abc.middleware import TaskiqMiddleware
@@ -46,12 +46,14 @@ def __init__(
4646
task_name: str,
4747
broker: "AsyncBroker",
4848
labels: Dict[str, Any],
49+
return_type: Optional[TypeAdapter[_ReturnType]] = None,
4950
) -> None:
5051
self.task_name = task_name
5152
self.broker = broker
5253
self.labels = labels
5354
self.custom_task_id: Optional[str] = None
5455
self.custom_schedule_id: Optional[str] = None
56+
self.return_type = return_type
5557

5658
def with_labels(
5759
self,
@@ -169,6 +171,7 @@ async def kiq(
169171
return AsyncTaskiqTask(
170172
task_id=message.task_id,
171173
result_backend=self.broker.result_backend,
174+
return_type=self.return_type, # type: ignore # (pyright issue)
172175
)
173176

174177
async def schedule_by_cron(

taskiq/task.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
2+
from logging import getLogger
23
from time import time
34
from typing import TYPE_CHECKING, Any, Generic, Optional
45

6+
from pydantic import TypeAdapter
57
from typing_extensions import TypeVar
68

79
from taskiq.exceptions import (
@@ -15,6 +17,8 @@
1517
from taskiq.depends.progress_tracker import TaskProgress
1618
from taskiq.result import TaskiqResult
1719

20+
logger = getLogger("taskiq.task")
21+
1822
_ReturnType = TypeVar("_ReturnType")
1923

2024

@@ -25,9 +29,11 @@ def __init__(
2529
self,
2630
task_id: str,
2731
result_backend: "AsyncResultBackend[_ReturnType]",
32+
return_type: Optional[TypeAdapter[_ReturnType]] = None,
2833
) -> None:
2934
self.task_id = task_id
3035
self.result_backend = result_backend
36+
self.return_type = return_type
3137

3238
async def is_ready(self) -> bool:
3339
"""
@@ -53,10 +59,18 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType
5359
:return: task's return value.
5460
"""
5561
try:
56-
return await self.result_backend.get_result(
62+
res = await self.result_backend.get_result(
5763
self.task_id,
5864
with_logs=with_logs,
5965
)
66+
if self.return_type is not None:
67+
try:
68+
res.return_value = self.return_type.validate_python(
69+
res.return_value,
70+
)
71+
except ValueError:
72+
logger.warning("Cannot parse return type into %s", self.return_type)
73+
return res
6074
except Exception as exc:
6175
raise ResultGetError from exc
6276

0 commit comments

Comments
 (0)