Skip to content

Commit 333c30f

Browse files
committed
Added serializer, removed bug with duplicate key
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
1 parent 9659e78 commit 333c30f

File tree

4 files changed

+58
-10
lines changed

4 files changed

+58
-10
lines changed

taskiq_psqlpy/queries.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
INSERT_RESULT_QUERY = """
1313
INSERT INTO {} VALUES ($1, $2)
14+
ON CONFLICT (task_id)
15+
DO UPDATE
16+
SET result = $2
1417
"""
1518

1619
IS_RESULT_EXISTS_QUERY = """

taskiq_psqlpy/result_backend.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
from typing import (
32
Any,
43
Final,
@@ -11,6 +10,9 @@
1110
from psqlpy import ConnectionPool
1211
from psqlpy.exceptions import RustPSQLDriverPyBaseError
1312
from taskiq import AsyncResultBackend, TaskiqResult
13+
from taskiq.abc.serializer import TaskiqSerializer
14+
from taskiq.compat import model_dump, model_validate
15+
from taskiq.serializers import PickleSerializer
1416

1517
from taskiq_psqlpy.exceptions import ResultIsMissingError
1618
from taskiq_psqlpy.queries import (
@@ -34,19 +36,24 @@ def __init__(
3436
keep_results: bool = True,
3537
table_name: str = "taskiq_results",
3638
field_for_task_id: Literal["VarChar", "Text"] = "VarChar",
39+
serializer: Optional[TaskiqSerializer] = None,
3740
**connect_kwargs: Any,
3841
) -> None:
3942
"""Construct new result backend.
4043
4144
:param dsn: connection string to PostgreSQL.
4245
:param keep_results: flag to not remove results from Redis after reading.
46+
:param table_name: name of the table to store results.
47+
:param field_for_task_id: type of the field to store task_id.
48+
:param serializer: serializer class to serialize/deserialize result from task.
4349
:param connect_kwargs: additional arguments for nats `ConnectionPool` class.
4450
"""
4551
self.dsn: Final = dsn
4652
self.keep_results: Final = keep_results
4753
self.table_name: Final = table_name
4854
self.field_for_task_id: Final = field_for_task_id
4955
self.connect_kwargs: Final = connect_kwargs
56+
self.serializer = serializer or PickleSerializer()
5057

5158
self._database_pool: ConnectionPool
5259

@@ -93,7 +100,7 @@ async def set_result(
93100
),
94101
parameters=[
95102
task_id,
96-
pickle.dumps(result),
103+
self.serializer.dumpb(model_dump(result)),
97104
],
98105
)
99106

@@ -149,8 +156,9 @@ async def get_result(
149156
parameters=[task_id],
150157
)
151158

152-
taskiq_result: Final = pickle.loads( # noqa: S301
153-
result_in_bytes,
159+
taskiq_result: Final = model_validate(
160+
TaskiqResult[_ReturnType],
161+
self.serializer.loadb(result_in_bytes),
154162
)
155163

156164
if not with_logs:

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def postgresql_dsn() -> str:
4545
"""
4646
return (
4747
os.environ.get("POSTGRESQL_URL")
48-
or "postgresql://postgres:postgres@localhost:5432/taskiqpsqlpy"
48+
or "postgresql://akiselev:12345@localhost:5432/taskiqpsqlpy"
4949
)
5050

5151

tests/test_result_backend.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, TypeVar
33

44
import pytest
5+
from pydantic import BaseModel
56
from taskiq import TaskiqResult
67

78
from taskiq_psqlpy.exceptions import ResultIsMissingError
@@ -11,12 +12,10 @@
1112
pytestmark = pytest.mark.anyio
1213

1314

14-
class ResultForTest:
15+
class ResultForTest(BaseModel):
1516
"""Just test class for testing."""
1617

17-
def __init__(self) -> None:
18-
"""Generates test class for result testing."""
19-
self.test_arg = uuid.uuid4()
18+
test_arg: str = uuid.uuid4().hex
2019

2120

2221
@pytest.fixture
@@ -137,7 +136,7 @@ async def test_success_backend_custom_result(
137136
result = await psqlpy_result_backend.get_result(task_id=task_id)
138137

139138
assert (
140-
result.return_value.test_arg # type: ignore
139+
result.return_value["test_arg"] # type: ignore
141140
== custom_taskiq_result.return_value.test_arg # type: ignore
142141
)
143142
assert result.is_err == custom_taskiq_result.is_err
@@ -158,3 +157,41 @@ async def test_success_backend_is_result_ready(
158157
)
159158

160159
assert await psqlpy_result_backend.is_result_ready(task_id=task_id)
160+
161+
162+
async def test_test_success_backend_custom_result_set_same_task_id(
163+
psqlpy_result_backend: PSQLPyResultBackend[_ReturnType],
164+
custom_taskiq_result: TaskiqResult[_ReturnType],
165+
task_id: str,
166+
) -> None:
167+
await psqlpy_result_backend.set_result(
168+
task_id=task_id,
169+
result=custom_taskiq_result,
170+
)
171+
result = await psqlpy_result_backend.get_result(task_id=task_id)
172+
173+
assert (
174+
result.return_value["test_arg"] # type: ignore
175+
== custom_taskiq_result.return_value.test_arg # type: ignore
176+
)
177+
178+
await psqlpy_result_backend.set_result(
179+
task_id=task_id,
180+
result=custom_taskiq_result,
181+
)
182+
183+
another_taskiq_res_uuid = uuid.uuid4().hex
184+
another_taskiq_res = TaskiqResult(
185+
is_err=False,
186+
log=None,
187+
return_value=ResultForTest(test_arg=another_taskiq_res_uuid),
188+
execution_time=0.1,
189+
)
190+
191+
await psqlpy_result_backend.set_result(
192+
task_id=task_id,
193+
result=another_taskiq_res, # type: ignore[arg-type]
194+
)
195+
result = await psqlpy_result_backend.get_result(task_id=task_id)
196+
197+
assert result.return_value["test_arg"] == another_taskiq_res_uuid # type: ignore

0 commit comments

Comments
 (0)