11import json
2- from typing import Any , Optional
2+ import typing
3+ from typing import Any , Optional , Type
34
5+ from pydantic import BaseModel , create_model
46from pydantic .json import pydantic_encoder
7+
8+ import temporalio .workflow
59from temporalio .api .common .v1 import Payload
610from temporalio .converter import (
711 CompositePayloadConverter ,
812 DataConverter ,
913 DefaultPayloadConverter ,
1014 JSONPlainPayloadConverter ,
1115)
16+ from temporalio .worker .workflow_sandbox ._restrictions import (
17+ RestrictionContext ,
18+ )
1219
1320
1421class PydanticJSONPayloadConverter (JSONPlainPayloadConverter ):
@@ -33,6 +40,26 @@ def to_payload(self, value: Any) -> Optional[Payload]:
3340 ).encode (),
3441 )
3542
43+ def from_payload (self , payload : Payload , type_hint : Optional [Type ] = None ) -> Any :
44+ data = json .loads (payload .data .decode ())
45+ if type_hint and typing .get_origin (type_hint ) is list :
46+ assert isinstance (data , list ), "Expected list"
47+ [type_hint ] = typing .get_args (type_hint )
48+ assert type_hint is not None , "Expected type hint"
49+ assert issubclass (type_hint , BaseModel ), "Expected BaseModel"
50+ if temporalio .workflow .unsafe .in_sandbox ():
51+ type_hint = _unwrap_restricted_fields (type_hint )
52+
53+ return [self ._from_dict (d , type_hint ) for d in data ]
54+ return self ._from_dict (data , type_hint )
55+
56+ def _from_dict (self , data : dict , type_hint : Optional [Type ]) -> Any :
57+ assert isinstance (data , dict ), "Expected dict"
58+ if type_hint and hasattr (type_hint , "validate" ):
59+ return type_hint .validate (data )
60+
61+ return data
62+
3663
3764class PydanticPayloadConverter (CompositePayloadConverter ):
3865 """Payload converter that replaces Temporal JSON conversion with Pydantic
@@ -54,3 +81,13 @@ def __init__(self) -> None:
5481 payload_converter_class = PydanticPayloadConverter
5582)
5683"""Data converter using Pydantic JSON conversion."""
84+
85+
86+ def _unwrap_restricted_fields (
87+ model : Type [BaseModel ],
88+ ) -> Type [BaseModel ]:
89+ fields = {
90+ name : (RestrictionContext .unwrap_if_proxied (f .annotation ), f )
91+ for name , f in model .model_fields .items ()
92+ }
93+ return create_model (model .__name__ , ** fields ) # type: ignore
0 commit comments