Skip to content

Commit f323fcd

Browse files
authored
feat: upgrade to pydantic v2 (#160)
1 parent a9705a8 commit f323fcd

File tree

11 files changed

+399
-144
lines changed

11 files changed

+399
-144
lines changed

poetry.lock

Lines changed: 219 additions & 125 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ keywords = ["taskiq", "tasks", "distributed", "async"]
2828
[tool.poetry.dependencies]
2929
python = "^3.8.1"
3030
typing-extensions = ">=3.10.0.0"
31-
pydantic = "^1.6.2"
31+
pydantic = ">=1.0,<=3.0"
3232
importlib-metadata = "*"
3333
pycron = "^3.0.0"
3434
taskiq_dependencies = "^1"

taskiq/compat.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# flake8: noqa
2+
from typing import Any, Dict, Optional, Type, TypeVar, Union
3+
4+
import pydantic
5+
from importlib_metadata import version
6+
from packaging.version import Version, parse
7+
8+
PYDANTIC_VER = parse(version("pydantic"))
9+
10+
Model = TypeVar("Model", bound="pydantic.BaseModel")
11+
12+
13+
if PYDANTIC_VER >= Version("2.0"):
14+
T = TypeVar("T")
15+
16+
def parse_obj_as(annot: T, obj: Any) -> T:
17+
return pydantic.TypeAdapter(annot).validate_python(obj)
18+
19+
def model_validate_json(
20+
model_class: Type[Model],
21+
message: Union[str, bytes, bytearray],
22+
) -> Model:
23+
return model_class.model_validate_json(message)
24+
25+
def model_dump_json(instance: Model) -> str:
26+
return instance.model_dump_json()
27+
28+
def model_copy(
29+
instance: Model,
30+
update: Optional[Dict[str, Any]] = None,
31+
deep: bool = False,
32+
) -> Model:
33+
return instance.model_copy(update=update, deep=deep)
34+
35+
else:
36+
parse_obj_as = pydantic.parse_obj_as # type: ignore
37+
38+
def model_validate_json(
39+
model_class: Type[Model],
40+
message: Union[str, bytes, bytearray],
41+
) -> Model:
42+
return model_class.parse_raw(message)
43+
44+
def model_dump_json(instance: Model) -> str:
45+
return instance.json()
46+
47+
def model_copy(
48+
instance: Model,
49+
update: Optional[Dict[str, Any]] = None,
50+
deep: bool = False,
51+
) -> Model:
52+
return instance.copy(update=update, deep=deep)

taskiq/formatters/json_formatter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from taskiq.abc.formatter import TaskiqFormatter
2+
from taskiq.compat import model_dump_json, model_validate_json
23
from taskiq.message import BrokerMessage, TaskiqMessage
34

45

@@ -15,7 +16,7 @@ def dumps(self, message: TaskiqMessage) -> BrokerMessage:
1516
return BrokerMessage(
1617
task_id=message.task_id,
1718
task_name=message.task_name,
18-
message=message.json().encode(),
19+
message=model_dump_json(message).encode(),
1920
labels=message.labels,
2021
)
2122

@@ -26,4 +27,4 @@ def loads(self, message: bytes) -> TaskiqMessage:
2627
:param message: broker's message.
2728
:return: parsed taskiq message.
2829
"""
29-
return TaskiqMessage.parse_raw(message)
30+
return model_validate_json(TaskiqMessage, message)

taskiq/receiver/params_parser.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from logging import getLogger
33
from typing import Any, Dict, Optional
44

5-
from pydantic import parse_obj_as
6-
5+
from taskiq.compat import parse_obj_as
76
from taskiq.message import TaskiqMessage
87

98
logger = getLogger(__name__)

taskiq/result/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# flake8: noqa
2+
from packaging.version import Version
3+
4+
from taskiq.compat import PYDANTIC_VER
5+
6+
if PYDANTIC_VER >= Version("2.0"):
7+
from .v2 import TaskiqResult
8+
else:
9+
from .v1 import TaskiqResult
10+
11+
12+
__all__ = [
13+
"TaskiqResult",
14+
]
File renamed without changes.

taskiq/result/v2.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import json
2+
import pickle # noqa: S403
3+
from typing import Any, Dict, Generic, Optional, TypeVar
4+
5+
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
6+
from typing_extensions import Self
7+
8+
from taskiq.serialization import exception_to_python, prepare_exception
9+
10+
_ReturnType = TypeVar("_ReturnType")
11+
12+
13+
class TaskiqResult(BaseModel, Generic[_ReturnType]):
14+
"""Result of a remote task invocation."""
15+
16+
is_err: bool
17+
# Log is a deprecated field. It would be removed in future
18+
# releases of not, if we find a way to capture logs in async
19+
# environment.
20+
log: Optional[str] = None
21+
return_value: _ReturnType
22+
execution_time: float
23+
labels: Dict[str, str] = Field(default_factory=dict)
24+
25+
error: Optional[BaseException] = None
26+
27+
model_config = ConfigDict(arbitrary_types_allowed=True)
28+
29+
@field_serializer("error")
30+
def serialize_error(self, value: BaseException) -> Any:
31+
"""
32+
Serialize error field.
33+
34+
:returns: Any
35+
:param value: exception to serialize.
36+
"""
37+
if value:
38+
return prepare_exception(value, json)
39+
40+
return None
41+
42+
def raise_for_error(self) -> "Self":
43+
"""Raise exception if `error`.
44+
45+
:raises error: task execution exception
46+
:returns: TaskiqResult
47+
"""
48+
if self.error is not None:
49+
raise self.error
50+
return self
51+
52+
def __getstate__(self) -> Dict[Any, Any]:
53+
dict = super().__getstate__() # noqa: WPS125
54+
vals: Dict[str, Any] = dict["__dict__"]
55+
56+
if "error" in vals and vals["error"] is not None:
57+
vals["error"] = prepare_exception(
58+
vals["error"],
59+
pickle,
60+
)
61+
62+
return dict
63+
64+
@field_validator("error", mode="before")
65+
@classmethod
66+
def _validate_error(cls, value: Any) -> Optional[BaseException]:
67+
return exception_to_python(value)

taskiq/serialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def _prepare_exception(
297297
SEEN_EXCEPTIONS_CACHE.discard(id(exc))
298298

299299

300-
@pydantic.validate_arguments(config={"arbitrary_types_allowed": True})
300+
@pydantic.validate_arguments(config={"arbitrary_types_allowed": True}) # type: ignore
301301
def prepare_exception(
302302
exc: BaseException,
303303
coder: Coder[Any, Any],
@@ -312,7 +312,7 @@ def prepare_exception(
312312
return _prepare_exception(exc, coder) # type: ignore
313313

314314

315-
@pydantic.validate_arguments(config={"arbitrary_types_allowed": True})
315+
@pydantic.validate_arguments(config={"arbitrary_types_allowed": True}) # type: ignore
316316
def exception_to_python( # noqa: C901, WPS210
317317
exc: Optional[Union[BaseException, ExceptionRepr]],
318318
) -> Optional[BaseException]:

tests/cli/worker/test_parameters_parsing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from pydantic import BaseModel
77

8+
from taskiq.compat import model_copy
89
from taskiq.message import TaskiqMessage
910
from taskiq.receiver.params_parser import parse_params
1011

@@ -27,7 +28,7 @@ def test_parse_params_no_signature() -> None:
2728
args=[1, 2],
2829
kwargs={"a": 1},
2930
)
30-
modify_msg = src_msg.copy(deep=True)
31+
modify_msg = model_copy(src_msg, deep=True)
3132
parse_params(
3233
signature=None,
3334
type_hints={},

0 commit comments

Comments
 (0)