1+ import inspect
12import json
2- from typing import Any , Optional
3+ from typing import (
4+ Any ,
5+ Optional ,
6+ Type ,
7+ )
38
9+ import pydantic
410from pydantic .json import pydantic_encoder
11+
12+ import temporalio .workflow
513from temporalio .api .common .v1 import Payload
614from temporalio .converter import (
715 CompositePayloadConverter ,
816 DataConverter ,
917 DefaultPayloadConverter ,
1018 JSONPlainPayloadConverter ,
19+ JSONTypeConverter ,
1120)
21+ from temporalio .worker .workflow_sandbox ._restrictions import RestrictionContext
1222
1323
1424class PydanticJSONPayloadConverter (JSONPlainPayloadConverter ):
1525 """Pydantic JSON payload converter.
1626
17- This extends the :py:class:`JSONPlainPayloadConverter` to override
18- :py:meth:`to_payload` using the Pydantic encoder.
27+ Extends :py:class:`JSONPlainPayloadConverter` to override :py:meth:`to_payload` using
28+ the Pydantic encoder. :py:meth:`from_payload` uses the parent implementation, with a
29+ custom type converter.
1930 """
2031
32+ def __init__ (self ) -> None :
33+ super ().__init__ (custom_type_converters = [PydanticModelTypeConverter ()])
34+
2135 def to_payload (self , value : Any ) -> Optional [Payload ]:
2236 """Convert all values with Pydantic encoder or fail.
2337
2438 Like the base class, we fail if we cannot convert. This payload
2539 converter is expected to be the last in the chain, so it can fail if
2640 unable to convert.
2741 """
28- # We let JSON conversion errors be thrown to caller
42+ # Let JSON conversion errors be thrown to caller
2943 return Payload (
3044 metadata = {"encoding" : self .encoding .encode ()},
3145 data = json .dumps (
@@ -34,17 +48,47 @@ def to_payload(self, value: Any) -> Optional[Payload]:
3448 )
3549
3650
51+ class PydanticModelTypeConverter (JSONTypeConverter ):
52+ def to_typed_value (self , hint : Type , value : Any ) -> Any :
53+ if not inspect .isclass (hint ) or not issubclass (hint , pydantic .BaseModel ):
54+ return JSONTypeConverter .Unhandled
55+ model = hint
56+ if not isinstance (value , dict ):
57+ raise TypeError (
58+ f"Cannot convert to { model } , value is { type (value )} not dict"
59+ )
60+ if temporalio .workflow .unsafe .in_sandbox ():
61+ # Unwrap proxied model field types so that Pydantic can call their constructors
62+ model = pydantic .create_model (
63+ model .__name__ ,
64+ ** { # type: ignore
65+ name : (RestrictionContext .unwrap_if_proxied (f .annotation ), f )
66+ for name , f in model .model_fields .items ()
67+ },
68+ )
69+ if hasattr (model , "model_validate" ):
70+ return model .model_validate (value )
71+ elif hasattr (model , "parse_obj" ):
72+ # Pydantic v1
73+ return model .parse_obj (value )
74+ else :
75+ raise ValueError (
76+ f"{ model } is a Pydantic model but does not have a `model_validate` or `parse_obj` method"
77+ )
78+
79+
3780class PydanticPayloadConverter (CompositePayloadConverter ):
3881 """Payload converter that replaces Temporal JSON conversion with Pydantic
3982 JSON conversion.
4083 """
4184
4285 def __init__ (self ) -> None :
86+ json_payload_converter = PydanticJSONPayloadConverter ()
4387 super ().__init__ (
4488 * (
4589 c
4690 if not isinstance (c , JSONPlainPayloadConverter )
47- else PydanticJSONPayloadConverter ()
91+ else json_payload_converter
4892 for c in DefaultPayloadConverter .default_encoding_payload_converters
4993 )
5094 )
0 commit comments