Skip to content

Commit ef79505

Browse files
committed
Copy of pydantic v1 converter from samples-python
1 parent 51f4b66 commit ef79505

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import json
2+
from typing import Any, Optional
3+
4+
from pydantic.json import pydantic_encoder
5+
from temporalio.api.common.v1 import Payload
6+
from temporalio.converter import (
7+
CompositePayloadConverter,
8+
DataConverter,
9+
DefaultPayloadConverter,
10+
JSONPlainPayloadConverter,
11+
)
12+
13+
14+
class PydanticJSONPayloadConverter(JSONPlainPayloadConverter):
15+
"""Pydantic JSON payload converter.
16+
17+
This extends the :py:class:`JSONPlainPayloadConverter` to override
18+
:py:meth:`to_payload` using the Pydantic encoder.
19+
"""
20+
21+
def to_payload(self, value: Any) -> Optional[Payload]:
22+
"""Convert all values with Pydantic encoder or fail.
23+
24+
Like the base class, we fail if we cannot convert. This payload
25+
converter is expected to be the last in the chain, so it can fail if
26+
unable to convert.
27+
"""
28+
# We let JSON conversion errors be thrown to caller
29+
return Payload(
30+
metadata={"encoding": self.encoding.encode()},
31+
data=json.dumps(
32+
value, separators=(",", ":"), sort_keys=True, default=pydantic_encoder
33+
).encode(),
34+
)
35+
36+
37+
class PydanticPayloadConverter(CompositePayloadConverter):
38+
"""Payload converter that replaces Temporal JSON conversion with Pydantic
39+
JSON conversion.
40+
"""
41+
42+
def __init__(self) -> None:
43+
super().__init__(
44+
*(
45+
(
46+
c
47+
if not isinstance(c, JSONPlainPayloadConverter)
48+
else PydanticJSONPayloadConverter()
49+
)
50+
for c in DefaultPayloadConverter.default_encoding_payload_converters
51+
)
52+
)
53+
54+
55+
pydantic_data_converter = DataConverter(
56+
payload_converter_class=PydanticPayloadConverter
57+
)
58+
"""Data converter using Pydantic JSON conversion."""

tests/contrib/test_pydantic.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import dataclasses
2+
import uuid
3+
from datetime import datetime, timedelta
4+
from ipaddress import IPv4Address
5+
from typing import List
6+
7+
from pydantic import BaseModel
8+
9+
from temporalio import activity, workflow
10+
from temporalio.client import Client
11+
from temporalio.contrib.pydantic.converter import pydantic_data_converter
12+
from temporalio.worker import Worker
13+
from temporalio.worker.workflow_sandbox import (
14+
SandboxedWorkflowRunner,
15+
SandboxRestrictions,
16+
)
17+
18+
19+
class MyPydanticModel(BaseModel):
20+
some_ip: IPv4Address
21+
some_date: datetime
22+
23+
24+
@activity.defn
25+
async def my_activity(models: List[MyPydanticModel]) -> List[MyPydanticModel]:
26+
activity.logger.info("Got models in activity: %s" % models)
27+
return models
28+
29+
30+
@workflow.defn
31+
class MyWorkflow:
32+
@workflow.run
33+
async def run(self, models: List[MyPydanticModel]) -> List[MyPydanticModel]:
34+
workflow.logger.info("Got models in workflow: %s" % models)
35+
return await workflow.execute_activity(
36+
my_activity, models, start_to_close_timeout=timedelta(minutes=1)
37+
)
38+
39+
40+
# Due to known issues with Pydantic's use of issubclass and our inability to
41+
# override the check in sandbox, Pydantic will think datetime is actually date
42+
# in the sandbox. At the expense of protecting against datetime.now() use in
43+
# workflows, we're going to remove datetime module restrictions. See sdk-python
44+
# README's discussion of known sandbox issues for more details.
45+
def new_sandbox_runner() -> SandboxedWorkflowRunner:
46+
# TODO(cretz): Use with_child_unrestricted when https://github.com/temporalio/sdk-python/issues/254
47+
# is fixed and released
48+
invalid_module_member_children = dict(
49+
SandboxRestrictions.invalid_module_members_default.children
50+
)
51+
del invalid_module_member_children["datetime"]
52+
return SandboxedWorkflowRunner(
53+
restrictions=dataclasses.replace(
54+
SandboxRestrictions.default,
55+
invalid_module_members=dataclasses.replace(
56+
SandboxRestrictions.invalid_module_members_default,
57+
children=invalid_module_member_children,
58+
),
59+
)
60+
)
61+
62+
63+
async def test_workflow_with_pydantic_model(client: Client):
64+
# Replace data converter in client
65+
new_config = client.config()
66+
new_config["data_converter"] = pydantic_data_converter
67+
client = Client(**new_config)
68+
task_queue_name = str(uuid.uuid4())
69+
70+
orig_models = [
71+
MyPydanticModel(
72+
some_ip=IPv4Address("127.0.0.1"),
73+
some_date=datetime(2000, 1, 2, 3, 4, 5),
74+
),
75+
MyPydanticModel(
76+
some_ip=IPv4Address("127.0.0.2"),
77+
some_date=datetime(2001, 2, 3, 4, 5, 6),
78+
),
79+
]
80+
81+
async with Worker(
82+
client,
83+
task_queue=task_queue_name,
84+
workflows=[MyWorkflow],
85+
activities=[my_activity],
86+
workflow_runner=new_sandbox_runner(),
87+
):
88+
result = await client.execute_workflow(
89+
MyWorkflow.run,
90+
orig_models,
91+
id=str(uuid.uuid4()),
92+
task_queue=task_queue_name,
93+
)
94+
assert orig_models == result

0 commit comments

Comments
 (0)