Skip to content

Commit 1e09d52

Browse files
committed
Implement pydantic from_payload
1 parent 1e1c2c8 commit 1e09d52

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

temporalio/contrib/pydantic/converter.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
import json
2-
from typing import Any, Optional
2+
import typing
3+
from typing import Any, Optional, Type
34

5+
from pydantic import BaseModel, create_model
46
from pydantic.json import pydantic_encoder
7+
8+
import temporalio.workflow
59
from temporalio.api.common.v1 import Payload
610
from temporalio.converter import (
711
CompositePayloadConverter,
812
DataConverter,
913
DefaultPayloadConverter,
1014
JSONPlainPayloadConverter,
1115
)
16+
from temporalio.worker.workflow_sandbox._restrictions import (
17+
_RestrictedProxy,
18+
_unwrap_restricted_proxy,
19+
)
1220

1321

1422
class PydanticJSONPayloadConverter(JSONPlainPayloadConverter):
@@ -33,6 +41,26 @@ def to_payload(self, value: Any) -> Optional[Payload]:
3341
).encode(),
3442
)
3543

44+
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
45+
data = json.loads(payload.data.decode())
46+
if type_hint and typing.get_origin(type_hint) is list:
47+
assert isinstance(data, list), "Expected list"
48+
[type_hint] = typing.get_args(type_hint)
49+
assert type_hint is not None, "Expected type hint"
50+
assert issubclass(type_hint, BaseModel), "Expected BaseModel"
51+
if temporalio.workflow.unsafe.in_sandbox():
52+
type_hint = _unwrap_restricted_fields(type_hint)
53+
54+
return [self._from_dict(d, type_hint) for d in data]
55+
return self._from_dict(data, type_hint)
56+
57+
def _from_dict(self, data: dict, type_hint: Optional[Type]) -> Any:
58+
assert isinstance(data, dict), "Expected dict"
59+
if type_hint and hasattr(type_hint, "validate"):
60+
return type_hint.validate(data)
61+
62+
return data
63+
3664

3765
class PydanticPayloadConverter(CompositePayloadConverter):
3866
"""Payload converter that replaces Temporal JSON conversion with Pydantic
@@ -56,3 +84,13 @@ def __init__(self) -> None:
5684
payload_converter_class=PydanticPayloadConverter
5785
)
5886
"""Data converter using Pydantic JSON conversion."""
87+
88+
89+
def _unwrap_restricted_fields(
90+
model: Type[BaseModel],
91+
) -> Type[BaseModel]:
92+
fields = {
93+
name: (_unwrap_restricted_proxy(f.annotation), f)
94+
for name, f in model.model_fields.items()
95+
}
96+
return create_model(model.__name__, **fields) # type: ignore

0 commit comments

Comments
 (0)