Skip to content

Commit 9267198

Browse files
committed
Implement type converter and __get_pydantic_core_schema__
1 parent 9b8524a commit 9267198

File tree

3 files changed

+64
-18
lines changed

3 files changed

+64
-18
lines changed

temporalio/contrib/pydantic/converter.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,45 @@
1+
import inspect
12
import json
2-
from typing import Any, Optional
3+
from typing import (
4+
Any,
5+
Optional,
6+
Type,
7+
)
38

9+
import pydantic
410
from pydantic.json import pydantic_encoder
11+
12+
import temporalio.workflow
513
from temporalio.api.common.v1 import Payload
614
from temporalio.converter import (
715
CompositePayloadConverter,
816
DataConverter,
917
DefaultPayloadConverter,
1018
JSONPlainPayloadConverter,
19+
JSONTypeConverter,
1120
)
21+
from temporalio.worker.workflow_sandbox._restrictions import RestrictionContext
1222

1323

1424
class PydanticJSONPayloadConverter(JSONPlainPayloadConverter):
1525
"""Pydantic JSON payload converter.
1626
17-
This extends the :py:class:`JSONPlainPayloadConverter` to override
18-
:py:meth:`to_payload` using the Pydantic encoder.
27+
Extends :py:class:`JSONPlainPayloadConverter` to override :py:meth:`to_payload` using
28+
the Pydantic encoder. :py:meth:`from_payload` uses the parent implementation, with a
29+
custom type converter.
1930
"""
2031

32+
def __init__(self) -> None:
33+
super().__init__(custom_type_converters=[PydanticModelTypeConverter()])
34+
2135
def to_payload(self, value: Any) -> Optional[Payload]:
2236
"""Convert all values with Pydantic encoder or fail.
2337
2438
Like the base class, we fail if we cannot convert. This payload
2539
converter is expected to be the last in the chain, so it can fail if
2640
unable to convert.
2741
"""
28-
# We let JSON conversion errors be thrown to caller
42+
# Let JSON conversion errors be thrown to caller
2943
return Payload(
3044
metadata={"encoding": self.encoding.encode()},
3145
data=json.dumps(
@@ -34,17 +48,47 @@ def to_payload(self, value: Any) -> Optional[Payload]:
3448
)
3549

3650

51+
class PydanticModelTypeConverter(JSONTypeConverter):
52+
def to_typed_value(self, hint: Type, value: Any) -> Any:
53+
if not inspect.isclass(hint) or not issubclass(hint, pydantic.BaseModel):
54+
return JSONTypeConverter.Unhandled
55+
model = hint
56+
if not isinstance(value, dict):
57+
raise TypeError(
58+
f"Cannot convert to {model}, value is {type(value)} not dict"
59+
)
60+
if temporalio.workflow.unsafe.in_sandbox():
61+
# Unwrap proxied model field types so that Pydantic can call their constructors
62+
model = pydantic.create_model(
63+
model.__name__,
64+
**{ # type: ignore
65+
name: (RestrictionContext.unwrap_if_proxied(f.annotation), f)
66+
for name, f in model.model_fields.items()
67+
},
68+
)
69+
if hasattr(model, "model_validate"):
70+
return model.model_validate(value)
71+
elif hasattr(model, "parse_obj"):
72+
# Pydantic v1
73+
return model.parse_obj(value)
74+
else:
75+
raise ValueError(
76+
f"{model} is a Pydantic model but does not have a `model_validate` or `parse_obj` method"
77+
)
78+
79+
3780
class PydanticPayloadConverter(CompositePayloadConverter):
3881
"""Payload converter that replaces Temporal JSON conversion with Pydantic
3982
JSON conversion.
4083
"""
4184

4285
def __init__(self) -> None:
86+
json_payload_converter = PydanticJSONPayloadConverter()
4387
super().__init__(
4488
*(
4589
c
4690
if not isinstance(c, JSONPlainPayloadConverter)
47-
else PydanticJSONPayloadConverter()
91+
else json_payload_converter
4892
for c in DefaultPayloadConverter.default_encoding_payload_converters
4993
)
5094
)

temporalio/converter.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -558,9 +558,10 @@ def encoding(self) -> str:
558558
def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
559559
"""See base class."""
560560
# Check for pydantic then send warning
561+
# TODO (dan): update
561562
if hasattr(value, "parse_obj"):
562563
warnings.warn(
563-
"If you're using pydantic model, refer to https://github.com/temporalio/samples-python/tree/main/pydantic_converter for better support"
564+
"If you're using a pydantic model, refer to https://github.com/temporalio/samples-python/tree/main/pydantic_converter for better support"
564565
)
565566
# We let JSON conversion errors be thrown to caller
566567
return temporalio.api.common.v1.Payload(
@@ -1522,18 +1523,6 @@ def value_to_type(
15221523
# TODO(cretz): Want way to convert snake case to camel case?
15231524
return hint(**field_values)
15241525

1525-
# If there is a @staticmethod or @classmethod parse_obj, we will use it.
1526-
# This covers Pydantic models.
1527-
parse_obj_attr = inspect.getattr_static(hint, "parse_obj", None)
1528-
if isinstance(parse_obj_attr, classmethod) or isinstance(
1529-
parse_obj_attr, staticmethod
1530-
):
1531-
if not isinstance(value, dict):
1532-
raise TypeError(
1533-
f"Cannot convert to {hint}, value is {type(value)} not dict"
1534-
)
1535-
return getattr(hint, "parse_obj")(value)
1536-
15371526
# IntEnum
15381527
if inspect.isclass(hint) and issubclass(hint, IntEnum):
15391528
if not isinstance(value, int):

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
cast,
3232
)
3333

34+
from pydantic import GetCoreSchemaHandler
35+
from pydantic_core import CoreSchema, core_schema
36+
3437
import temporalio.workflow
3538

3639
logger = logging.getLogger(__name__)
@@ -948,6 +951,14 @@ def _is_restrictable(v: Any) -> bool:
948951

949952

950953
class _RestrictedProxy:
954+
@classmethod
955+
def __get_pydantic_core_schema__(
956+
cls, source_type: Any, handler: GetCoreSchemaHandler
957+
) -> CoreSchema:
958+
return core_schema.no_info_after_validator_function(
959+
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
960+
)
961+
951962
def __init__(self, *args, **kwargs) -> None:
952963
# When we instantiate this class, we have the signature of:
953964
# __init__(
@@ -971,6 +982,8 @@ def __init__(self, *args, **kwargs) -> None:
971982
_trace("__init__ unrecognized with args %s", args)
972983

973984
def __getattribute__(self, __name: str) -> Any:
985+
if __name == "__get_pydantic_core_schema__":
986+
return object.__getattribute__(self, "__get_pydantic_core_schema__")
974987
state = _RestrictionState.from_proxy(self)
975988
_trace("__getattribute__ %s on %s", __name, state.name)
976989
# We do not restrict __spec__ or __name__

0 commit comments

Comments
 (0)