Skip to content

Commit 18b7ea3

Browse files
committed
Added tests.
1 parent ba55f43 commit 18b7ea3

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

tests/test_task.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import uuid
2+
from dataclasses import dataclass
3+
from typing import Dict, TypeVar
4+
5+
import pytest
6+
7+
from taskiq import serializers
8+
from taskiq.abc import AsyncResultBackend
9+
from taskiq.abc.serializer import TaskiqSerializer
10+
from taskiq.compat import model_dump, model_validate
11+
from taskiq.result.v1 import TaskiqResult
12+
from taskiq.task import AsyncTaskiqTask
13+
14+
_ReturnType = TypeVar("_ReturnType")
15+
16+
17+
class SerializingBackend(AsyncResultBackend[_ReturnType]):
18+
def __init__(self, serializer: TaskiqSerializer) -> None:
19+
self._serializer = serializer
20+
self._results: Dict[str, bytes] = {}
21+
22+
async def set_result(
23+
self,
24+
task_id: str,
25+
result: TaskiqResult[_ReturnType], # type: ignore
26+
) -> None:
27+
"""Set result with dumping."""
28+
self._results[task_id] = self._serializer.dumpb(model_dump(result))
29+
30+
async def is_result_ready(self, task_id: str) -> bool:
31+
"""Check if result is ready."""
32+
return task_id in self._results
33+
34+
async def get_result(
35+
self,
36+
task_id: str,
37+
with_logs: bool = False,
38+
) -> TaskiqResult[_ReturnType]:
39+
"""Get result with loading."""
40+
data = self._results[task_id]
41+
return model_validate(TaskiqResult, self._serializer.loadb(data))
42+
43+
44+
@pytest.mark.parametrize(
45+
"serializer",
46+
[
47+
serializers.MSGPackSerializer(),
48+
serializers.CBORSerializer(),
49+
serializers.PickleSerializer(),
50+
serializers.JSONSerializer(),
51+
],
52+
)
53+
@pytest.mark.anyio
54+
async def test_res_parsing_success(serializer: TaskiqSerializer) -> None:
55+
@dataclass
56+
class MyResult:
57+
name: str
58+
age: int
59+
60+
res = MyResult(name="test", age=10)
61+
res_back: AsyncResultBackend[MyResult] = SerializingBackend(serializer)
62+
test_id = str(uuid.uuid4())
63+
await res_back.set_result(
64+
test_id,
65+
TaskiqResult(
66+
is_err=False,
67+
return_value=res,
68+
execution_time=0.0,
69+
),
70+
)
71+
sent_task = AsyncTaskiqTask(test_id, res_back, MyResult)
72+
parsed = await sent_task.wait_result()
73+
assert isinstance(parsed.return_value, MyResult)

0 commit comments

Comments
 (0)