Skip to content

Commit 7dd09a2

Browse files
committed
Implement pydantic from_payload and __get_pydantic_core_schema__
1 parent 290d9da commit 7dd09a2

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

temporalio/contrib/pydantic/converter.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
import json
2-
from typing import Any, Optional
2+
import typing
3+
from typing import Any, Optional, Type
34

5+
from pydantic import BaseModel, create_model
46
from pydantic.json import pydantic_encoder
7+
8+
import temporalio.workflow
59
from temporalio.api.common.v1 import Payload
610
from temporalio.converter import (
711
CompositePayloadConverter,
812
DataConverter,
913
DefaultPayloadConverter,
1014
JSONPlainPayloadConverter,
1115
)
16+
from temporalio.worker.workflow_sandbox._restrictions import (
17+
RestrictionContext,
18+
)
1219

1320

1421
class PydanticJSONPayloadConverter(JSONPlainPayloadConverter):
@@ -33,6 +40,26 @@ def to_payload(self, value: Any) -> Optional[Payload]:
3340
).encode(),
3441
)
3542

43+
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
44+
data = json.loads(payload.data.decode())
45+
if type_hint and typing.get_origin(type_hint) is list:
46+
assert isinstance(data, list), "Expected list"
47+
[type_hint] = typing.get_args(type_hint)
48+
assert type_hint is not None, "Expected type hint"
49+
assert issubclass(type_hint, BaseModel), "Expected BaseModel"
50+
if temporalio.workflow.unsafe.in_sandbox():
51+
type_hint = _unwrap_restricted_fields(type_hint)
52+
53+
return [self._from_dict(d, type_hint) for d in data]
54+
return self._from_dict(data, type_hint)
55+
56+
def _from_dict(self, data: dict, type_hint: Optional[Type]) -> Any:
57+
assert isinstance(data, dict), "Expected dict"
58+
if type_hint and hasattr(type_hint, "validate"):
59+
return type_hint.validate(data)
60+
61+
return data
62+
3663

3764
class PydanticPayloadConverter(CompositePayloadConverter):
3865
"""Payload converter that replaces Temporal JSON conversion with Pydantic
@@ -54,3 +81,13 @@ def __init__(self) -> None:
5481
payload_converter_class=PydanticPayloadConverter
5582
)
5683
"""Data converter using Pydantic JSON conversion."""
84+
85+
86+
def _unwrap_restricted_fields(
87+
model: Type[BaseModel],
88+
) -> Type[BaseModel]:
89+
fields = {
90+
name: (RestrictionContext.unwrap_if_proxied(f.annotation), f)
91+
for name, f in model.model_fields.items()
92+
}
93+
return create_model(model.__name__, **fields) # type: ignore

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
cast,
3232
)
3333

34+
from pydantic import GetCoreSchemaHandler
35+
from pydantic_core import CoreSchema, core_schema
36+
3437
import temporalio.workflow
3538

3639
logger = logging.getLogger(__name__)
@@ -948,6 +951,14 @@ def _is_restrictable(v: Any) -> bool:
948951

949952

950953
class _RestrictedProxy:
954+
@classmethod
955+
def __get_pydantic_core_schema__(
956+
cls, source_type: Any, handler: GetCoreSchemaHandler
957+
) -> CoreSchema:
958+
return core_schema.no_info_after_validator_function(
959+
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
960+
)
961+
951962
def __init__(self, *args, **kwargs) -> None:
952963
# When we instantiate this class, we have the signature of:
953964
# __init__(
@@ -971,6 +982,8 @@ def __init__(self, *args, **kwargs) -> None:
971982
_trace("__init__ unrecognized with args %s", args)
972983

973984
def __getattribute__(self, __name: str) -> Any:
985+
if __name == "__get_pydantic_core_schema__":
986+
return object.__getattribute__(self, "__get_pydantic_core_schema__")
974987
state = _RestrictionState.from_proxy(self)
975988
_trace("__getattribute__ %s on %s", __name, state.name)
976989
# We do not restrict __spec__ or __name__

0 commit comments

Comments
 (0)