diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 69c9e25c..9c7fbe86 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -18,6 +18,7 @@ Optional, TypeVar, Union, + get_type_hints, overload, ) from uuid import uuid4 @@ -326,12 +327,18 @@ def inner( inner_task_name = f"{fmodule}:{fname}" wrapper = wraps(func) + sign = get_type_hints(func) + return_type = None + if "return" in sign: + return_type = sign["return"] + decorated_task = wrapper( self.decorator_class( broker=self, original_func=func, labels=inner_labels, task_name=inner_task_name, + return_type=return_type, # type: ignore ), ) diff --git a/taskiq/brokers/shared_broker.py b/taskiq/brokers/shared_broker.py index d6574e60..def2c797 100644 --- a/taskiq/brokers/shared_broker.py +++ b/taskiq/brokers/shared_broker.py @@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]: task_name=self.task_name, broker=broker, labels=self.labels, + return_type=self.return_type, ) diff --git a/taskiq/compat.py b/taskiq/compat.py index 1858d2c8..ce54bb9a 100644 --- a/taskiq/compat.py +++ b/taskiq/compat.py @@ -1,6 +1,6 @@ # flake8: noqa from functools import lru_cache -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union import pydantic from importlib_metadata import version @@ -12,13 +12,13 @@ IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0") if IS_PYDANTIC2: - T = TypeVar("T") + T = TypeVar("T", bound=Hashable) @lru_cache() - def create_type_adapter(annot: T) -> pydantic.TypeAdapter[T]: + def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]: return pydantic.TypeAdapter(annot) - def parse_obj_as(annot: T, obj: Any) -> T: + def parse_obj_as(annot: Type[T], obj: Any) -> T: return create_type_adapter(annot).validate_python(obj) def model_validate( diff --git a/taskiq/decor.py b/taskiq/decor.py index dcbb2de0..ba774506 100644 --- a/taskiq/decor.py +++ b/taskiq/decor.py @@ -7,6 +7,8 @@ Callable, Dict, Generic, + Optional, + Type, TypeVar, Union, overload, @@ -50,11 +52,13 @@ def __init__( task_name: str, original_func: Callable[_FuncParams, _ReturnType], labels: Dict[str, Any], + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.broker = broker self.task_name = task_name self.original_func = original_func self.labels = labels + self.return_type = return_type # Docs for this method are omitted in order to help # your IDE resolve correct docs for it. @@ -172,6 +176,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]: task_name=self.task_name, broker=self.broker, labels=self.labels, + return_type=self.return_type, ) def __repr__(self) -> str: diff --git a/taskiq/kicker.py b/taskiq/kicker.py index d2ff8e6e..9583df5d 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -9,6 +9,7 @@ Dict, Generic, Optional, + Type, TypeVar, Union, overload, @@ -46,12 +47,14 @@ def __init__( task_name: str, broker: "AsyncBroker", labels: Dict[str, Any], + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.task_name = task_name self.broker = broker self.labels = labels self.custom_task_id: Optional[str] = None self.custom_schedule_id: Optional[str] = None + self.return_type = return_type def with_labels( self, @@ -169,6 +172,7 @@ async def kiq( return AsyncTaskiqTask( task_id=message.task_id, result_backend=self.broker.result_backend, + return_type=self.return_type, # type: ignore # (pyright issue) ) async def schedule_by_cron( diff --git a/taskiq/task.py b/taskiq/task.py index c54e9d46..006dc520 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -1,9 +1,11 @@ import asyncio +from logging import getLogger from time import time -from typing import TYPE_CHECKING, Any, Generic, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, Type from typing_extensions import TypeVar +from taskiq.compat import parse_obj_as from taskiq.exceptions import ( ResultGetError, ResultIsReadyError, @@ -15,6 +17,8 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.result import TaskiqResult +logger = getLogger("taskiq.task") + _ReturnType = TypeVar("_ReturnType") @@ -25,9 +29,11 @@ def __init__( self, task_id: str, result_backend: "AsyncResultBackend[_ReturnType]", + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.task_id = task_id self.result_backend = result_backend + self.return_type = return_type async def is_ready(self) -> bool: """ @@ -53,10 +59,19 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType :return: task's return value. """ try: - return await self.result_backend.get_result( + res = await self.result_backend.get_result( self.task_id, with_logs=with_logs, ) + if self.return_type is not None: + try: + res.return_value = parse_obj_as( + self.return_type, + res.return_value, + ) + except ValueError: + logger.warning("Cannot parse return type into %s", self.return_type) + return res except Exception as exc: raise ResultGetError from exc diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 00000000..b0b76a91 --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,72 @@ +import uuid +from typing import Dict, TypeVar + +import pytest +from pydantic import BaseModel + +from taskiq import serializers +from taskiq.abc import AsyncResultBackend +from taskiq.abc.serializer import TaskiqSerializer +from taskiq.compat import model_dump, model_validate +from taskiq.result.v1 import TaskiqResult +from taskiq.task import AsyncTaskiqTask + +_ReturnType = TypeVar("_ReturnType") + + +class SerializingBackend(AsyncResultBackend[_ReturnType]): + def __init__(self, serializer: TaskiqSerializer) -> None: + self._serializer = serializer + self._results: Dict[str, bytes] = {} + + async def set_result( + self, + task_id: str, + result: TaskiqResult[_ReturnType], # type: ignore + ) -> None: + """Set result with dumping.""" + self._results[task_id] = self._serializer.dumpb(model_dump(result)) + + async def is_result_ready(self, task_id: str) -> bool: + """Check if result is ready.""" + return task_id in self._results + + async def get_result( + self, + task_id: str, + with_logs: bool = False, + ) -> TaskiqResult[_ReturnType]: + """Get result with loading.""" + data = self._results[task_id] + return model_validate(TaskiqResult, self._serializer.loadb(data)) + + +@pytest.mark.parametrize( + "serializer", + [ + serializers.MSGPackSerializer(), + serializers.CBORSerializer(), + serializers.PickleSerializer(), + serializers.JSONSerializer(), + ], +) +@pytest.mark.anyio +async def test_res_parsing_success(serializer: TaskiqSerializer) -> None: + class MyResult(BaseModel): + name: str + age: int + + res = MyResult(name="test", age=10) + res_back: AsyncResultBackend[MyResult] = SerializingBackend(serializer) + test_id = str(uuid.uuid4()) + await res_back.set_result( + test_id, + TaskiqResult( + is_err=False, + return_value=res, + execution_time=0.0, + ), + ) + sent_task = AsyncTaskiqTask(test_id, res_back, MyResult) + parsed = await sent_task.wait_result() + assert isinstance(parsed.return_value, MyResult)