Skip to content

Commit 2342a4e

Browse files
committed
Cleanup
1 parent 564cb90 commit 2342a4e

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

temporalio/contrib/pydantic.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pydantic_core import to_jsonable_python
2424
except ImportError:
2525
# pydantic v1
26-
from pydantic.json import pydantic_encoder as to_jsonable_python
26+
from pydantic.json import pydantic_encoder as to_jsonable_python # type: ignore
2727

2828
import temporalio.workflow
2929
from temporalio.converter import (
@@ -36,9 +36,16 @@
3636
)
3737
from temporalio.worker.workflow_sandbox._restrictions import RestrictionContext
3838

39+
# Note that in addition to the implementation in this module, _RestrictedProxy
40+
# implements __get_pydantic_core_schema__ so that pydantic unwraps proxied types
41+
# when determining the schema.
42+
3943

4044
class PydanticModelTypeConverter(JSONTypeConverter):
45+
"""Type converter for pydantic model instances."""
46+
4147
def to_typed_value(self, hint: Type, value: Any) -> Any:
48+
"""Convert dict value to pydantic model instance of the specified type"""
4249
if not inspect.isclass(hint) or not issubclass(hint, pydantic.BaseModel):
4350
return JSONTypeConverter.Unhandled
4451
model = hint
@@ -67,20 +74,29 @@ def to_typed_value(self, hint: Type, value: Any) -> Any:
6774

6875

6976
class PydanticJSONEncoder(AdvancedJSONEncoder):
77+
"""JSON encoder for python objects containing pydantic model instances."""
78+
7079
def default(self, o: Any) -> Any:
80+
"""Convert object to jsonable python.
81+
82+
See :py:meth:`json.JSONEncoder.default`.
83+
"""
7184
if isinstance(o, pydantic.BaseModel):
7285
return to_jsonable_python(o)
7386
return super().default(o)
7487

7588

7689
class PydanticPayloadConverter(CompositePayloadConverter):
77-
"""Pydantic payload converter.
90+
"""Payload converter for payloads containing pydantic model instances.
7891
79-
Payload converter that replaces the default JSON conversion with Pydantic
80-
JSON conversion.
92+
JSON conversion is replaced with a converter that uses
93+
:py:class:`PydanticJSONEncoder` to convert the python object to JSON, and
94+
:py:class:`PydanticModelTypeConverter` to convert raw python values to
95+
pydantic model instances.
8196
"""
8297

8398
def __init__(self) -> None:
99+
"""Initialize object"""
84100
json_payload_converter = JSONPlainPayloadConverter(
85101
encoder=PydanticJSONEncoder,
86102
custom_type_converters=[PydanticModelTypeConverter()],
@@ -98,7 +114,7 @@ def __init__(self) -> None:
98114
pydantic_data_converter = DataConverter(
99115
payload_converter_class=PydanticPayloadConverter
100116
)
101-
"""Data converter for Pydantic models.
117+
"""Data converter for payloads containing pydantic model instances.
102118
103-
To use, pass this as the ``data_converter`` argument to :py:class:`temporalio.client.Client`
119+
To use, pass as the ``data_converter`` argument of :py:class:`temporalio.client.Client`
104120
"""

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -951,14 +951,6 @@ def _is_restrictable(v: Any) -> bool:
951951

952952

953953
class _RestrictedProxy:
954-
@classmethod
955-
def __get_pydantic_core_schema__(
956-
cls, source_type: Any, handler: GetCoreSchemaHandler
957-
) -> CoreSchema:
958-
return core_schema.no_info_after_validator_function(
959-
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
960-
)
961-
962954
def __init__(self, *args, **kwargs) -> None:
963955
# When we instantiate this class, we have the signature of:
964956
# __init__(
@@ -1033,6 +1025,15 @@ def __getitem__(self, key: Any) -> Any:
10331025
)
10341026
return ret
10351027

1028+
# Instruct pydantic to use the proxied type when determining the schema
1029+
@classmethod
1030+
def __get_pydantic_core_schema__(
1031+
cls, source_type: Any, handler: GetCoreSchemaHandler
1032+
) -> CoreSchema:
1033+
return core_schema.no_info_after_validator_function(
1034+
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
1035+
)
1036+
10361037
__doc__ = _RestrictedProxyLookup( # type: ignore
10371038
class_value=__doc__, fallback_func=lambda self: type(self).__doc__, is_attr=True
10381039
)

0 commit comments

Comments
 (0)