Skip to content

Commit 3ac2a29

Browse files
committed
Add test of mixed type inputs
1 parent efcf011 commit 3ac2a29

File tree

1 file changed

+89
-30
lines changed

1 file changed

+89
-30
lines changed

tests/contrib/test_pydantic.py

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import dataclasses
12
import uuid
23
from datetime import datetime, timedelta
34
from ipaddress import IPv4Address
4-
from typing import Annotated, Any, List, Sequence, TypeVar
5+
from typing import Annotated, Any, List, Sequence, Tuple, TypeVar
56

67
from annotated_types import Len
78
from pydantic import BaseModel, Field, WithJsonSchema
@@ -35,29 +36,8 @@ class MyPydanticModel(BaseModel):
3536
datetime_short_sequence: ShortSequence[List[datetime]]
3637

3738

38-
@activity.defn
39-
async def my_activity(models: List[MyPydanticModel]) -> List[MyPydanticModel]:
40-
activity.logger.info("Got models in activity: %s" % models)
41-
return models
42-
43-
44-
@workflow.defn
45-
class MyWorkflow:
46-
@workflow.run
47-
async def run(self, models: List[MyPydanticModel]) -> List[MyPydanticModel]:
48-
workflow.logger.info("Got models in workflow: %s" % models)
49-
return await workflow.execute_activity(
50-
my_activity, models, start_to_close_timeout=timedelta(minutes=1)
51-
)
52-
53-
54-
async def test_field_conversion(client: Client):
55-
new_config = client.config()
56-
new_config["data_converter"] = pydantic_data_converter
57-
client = Client(**new_config)
58-
task_queue_name = str(uuid.uuid4())
59-
60-
orig_models = [
39+
def make_pydantic_objects() -> List[MyPydanticModel]:
40+
return [
6141
MyPydanticModel(
6242
ip_field=IPv4Address("127.0.0.1"),
6343
datetime_field=datetime(2000, 1, 2, 3, 4, 5),
@@ -94,16 +74,95 @@ async def test_field_conversion(client: Client):
9474
),
9575
]
9676

77+
78+
@activity.defn
79+
async def list_of_pydantic_models_activity(
80+
models: List[MyPydanticModel],
81+
) -> List[MyPydanticModel]:
82+
return models
83+
84+
85+
@workflow.defn
86+
class ListOfPydanticObjectsWorkflow:
87+
@workflow.run
88+
async def run(self, models: List[MyPydanticModel]) -> List[MyPydanticModel]:
89+
return await workflow.execute_activity(
90+
list_of_pydantic_models_activity,
91+
models,
92+
start_to_close_timeout=timedelta(minutes=1),
93+
)
94+
95+
96+
async def test_field_conversion(client: Client):
97+
new_config = client.config()
98+
new_config["data_converter"] = pydantic_data_converter
99+
client = Client(**new_config)
100+
task_queue_name = str(uuid.uuid4())
101+
102+
orig_pydantic_objects = make_pydantic_objects()
103+
104+
async with Worker(
105+
client,
106+
task_queue=task_queue_name,
107+
workflows=[ListOfPydanticObjectsWorkflow],
108+
activities=[list_of_pydantic_models_activity],
109+
):
110+
round_tripped_pydantic_objects = await client.execute_workflow(
111+
ListOfPydanticObjectsWorkflow.run,
112+
orig_pydantic_objects,
113+
id=str(uuid.uuid4()),
114+
task_queue=task_queue_name,
115+
)
116+
assert orig_pydantic_objects == round_tripped_pydantic_objects
117+
118+
119+
@dataclasses.dataclass
120+
class MyDataClass:
121+
int_field: int
122+
123+
124+
def make_dataclass_objects() -> List[MyDataClass]:
125+
return [MyDataClass(int_field=7)]
126+
127+
128+
@workflow.defn
129+
class MixedCollectionTypesWorkflow:
130+
@workflow.run
131+
async def run(
132+
self, input: Tuple[List[MyDataClass], List[MyPydanticModel]]
133+
) -> Tuple[List[MyDataClass], List[MyPydanticModel]]:
134+
data_classes, pydantic_objects = input
135+
pydantic_objects = await workflow.execute_activity(
136+
list_of_pydantic_models_activity,
137+
pydantic_objects,
138+
start_to_close_timeout=timedelta(minutes=1),
139+
)
140+
return data_classes, pydantic_objects
141+
142+
143+
async def test_mixed_collection_types(client: Client):
144+
new_config = client.config()
145+
new_config["data_converter"] = pydantic_data_converter
146+
client = Client(**new_config)
147+
task_queue_name = str(uuid.uuid4())
148+
149+
orig_dataclass_objects = make_dataclass_objects()
150+
orig_pydantic_objects = make_pydantic_objects()
151+
97152
async with Worker(
98153
client,
99154
task_queue=task_queue_name,
100-
workflows=[MyWorkflow],
101-
activities=[my_activity],
155+
workflows=[MixedCollectionTypesWorkflow],
156+
activities=[list_of_pydantic_models_activity],
102157
):
103-
result = await client.execute_workflow(
104-
MyWorkflow.run,
105-
orig_models,
158+
(
159+
round_tripped_dataclass_objects,
160+
round_tripped_pydantic_objects,
161+
) = await client.execute_workflow(
162+
MixedCollectionTypesWorkflow.run,
163+
(orig_dataclass_objects, orig_pydantic_objects),
106164
id=str(uuid.uuid4()),
107165
task_queue=task_queue_name,
108166
)
109-
assert orig_models == result
167+
assert orig_dataclass_objects == round_tripped_dataclass_objects
168+
assert orig_pydantic_objects == round_tripped_pydantic_objects

0 commit comments

Comments
 (0)