Skip to content

Commit fa5a695

Browse files
committed
Organize tests
1 parent f0b9779 commit fa5a695

File tree

1 file changed

+254
-16
lines changed

1 file changed

+254
-16
lines changed

tests/contrib/test_pydantic.py

Lines changed: 254 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,20 @@
22
import uuid
33
from datetime import date, datetime, timedelta
44
from ipaddress import IPv4Address
5-
from typing import Annotated, Any, List, Sequence, Tuple, TypeVar, Union, get_type_hints
5+
from typing import (
6+
Annotated,
7+
Any,
8+
Dict,
9+
Generic,
10+
List,
11+
Optional,
12+
Sequence,
13+
Set,
14+
Tuple,
15+
TypeVar,
16+
Union,
17+
get_type_hints,
18+
)
619

720
from annotated_types import Len
821
from pydantic import BaseModel, Field, WithJsonSchema
@@ -16,6 +29,205 @@
1629
ShortSequence = Annotated[SequenceType, Len(max_length=2)]
1730

1831

32+
class BasicTypesModel(BaseModel):
33+
int_field: int
34+
float_field: float
35+
str_field: str
36+
bool_field: bool
37+
bytes_field: bytes
38+
none_field: None
39+
40+
def _check_instance(self):
41+
assert isinstance(self.int_field, int)
42+
assert isinstance(self.float_field, float)
43+
assert isinstance(self.str_field, str)
44+
assert isinstance(self.bool_field, bool)
45+
assert isinstance(self.bytes_field, bytes)
46+
assert self.none_field is None
47+
assert self.int_field == 42
48+
assert self.float_field == 3.14
49+
assert self.str_field == "hello"
50+
assert self.bool_field is True
51+
assert self.bytes_field == b"world"
52+
53+
54+
def make_basic_types_object() -> BasicTypesModel:
55+
return BasicTypesModel(
56+
int_field=42,
57+
float_field=3.14,
58+
str_field="hello",
59+
bool_field=True,
60+
bytes_field=b"world",
61+
none_field=None,
62+
)
63+
64+
65+
class ComplexTypesModel(BaseModel):
66+
list_field: List[str]
67+
dict_field: Dict[str, int]
68+
set_field: Set[int]
69+
tuple_field: Tuple[str, int]
70+
union_field: Union[str, int]
71+
optional_field: Optional[str]
72+
73+
def _check_instance(self):
74+
assert isinstance(self.list_field, list)
75+
assert isinstance(self.dict_field, dict)
76+
assert isinstance(self.set_field, set)
77+
assert isinstance(self.tuple_field, tuple)
78+
assert isinstance(self.union_field, str)
79+
assert isinstance(self.optional_field, str)
80+
assert self.list_field == ["a", "b", "c"]
81+
assert self.dict_field == {"x": 1, "y": 2}
82+
assert self.set_field == {1, 2, 3}
83+
assert self.tuple_field == ("hello", 42)
84+
assert self.union_field == "string_or_int"
85+
assert self.optional_field == "present"
86+
87+
88+
def make_complex_types_object() -> ComplexTypesModel:
89+
return ComplexTypesModel(
90+
list_field=["a", "b", "c"],
91+
dict_field={"x": 1, "y": 2},
92+
set_field={1, 2, 3},
93+
tuple_field=("hello", 42),
94+
union_field="string_or_int",
95+
optional_field="present",
96+
)
97+
98+
99+
class SpecialTypesModel(BaseModel):
100+
datetime_field: datetime
101+
date_field: date
102+
timedelta_field: timedelta
103+
# path_field: Path
104+
uuid_field: uuid.UUID
105+
ip_field: IPv4Address
106+
107+
def _check_instance(self):
108+
assert isinstance(self.datetime_field, datetime)
109+
assert isinstance(self.date_field, date)
110+
assert isinstance(self.timedelta_field, timedelta)
111+
# assert isinstance(self.path_field, Path)
112+
assert isinstance(self.uuid_field, uuid.UUID)
113+
assert isinstance(self.ip_field, IPv4Address)
114+
assert self.datetime_field == datetime(2000, 1, 2, 3, 4, 5)
115+
assert self.date_field == date(2000, 1, 2)
116+
assert self.timedelta_field == timedelta(days=1, hours=2)
117+
# assert self.path_field == Path("test/path")
118+
assert self.uuid_field == uuid.UUID("12345678-1234-5678-1234-567812345678")
119+
assert self.ip_field == IPv4Address("127.0.0.1")
120+
121+
122+
def make_special_types_object() -> SpecialTypesModel:
123+
return SpecialTypesModel(
124+
datetime_field=datetime(2000, 1, 2, 3, 4, 5),
125+
date_field=date(2000, 1, 2),
126+
timedelta_field=timedelta(days=1, hours=2),
127+
# path_field=Path("test/path"),
128+
uuid_field=uuid.UUID("12345678-1234-5678-1234-567812345678"),
129+
ip_field=IPv4Address("127.0.0.1"),
130+
)
131+
132+
133+
class ChildModel(BaseModel):
134+
name: str
135+
value: int
136+
137+
138+
class ParentModel(BaseModel):
139+
child: ChildModel
140+
children: List[ChildModel]
141+
142+
def _check_instance(self):
143+
assert isinstance(self.child, ChildModel)
144+
assert isinstance(self.children, list)
145+
assert all(isinstance(child, ChildModel) for child in self.children)
146+
assert self.child.name == "child1"
147+
assert self.child.value == 1
148+
assert len(self.children) == 2
149+
assert self.children[0].name == "child2"
150+
assert self.children[0].value == 2
151+
assert self.children[1].name == "child3"
152+
assert self.children[1].value == 3
153+
154+
155+
def make_nested_object() -> ParentModel:
156+
return ParentModel(
157+
child=ChildModel(name="child1", value=1),
158+
children=[
159+
ChildModel(name="child2", value=2),
160+
ChildModel(name="child3", value=3),
161+
],
162+
)
163+
164+
165+
class FieldFeaturesModel(BaseModel):
166+
field_with_default: str = "default"
167+
field_with_factory: datetime = Field(
168+
default_factory=lambda: datetime(2000, 1, 2, 3, 4, 5)
169+
)
170+
field_with_constraints: int = Field(gt=0, lt=100)
171+
field_with_alias: str = Field(alias="different_name")
172+
173+
def _check_instance(self):
174+
assert isinstance(self.field_with_default, str)
175+
assert isinstance(self.field_with_factory, datetime)
176+
assert isinstance(self.field_with_constraints, int)
177+
assert isinstance(self.field_with_alias, str)
178+
assert self.field_with_default == "default"
179+
assert 0 < self.field_with_constraints < 100
180+
assert self.field_with_alias == "aliased_value"
181+
182+
183+
def make_field_features_object() -> FieldFeaturesModel:
184+
return FieldFeaturesModel(
185+
field_with_constraints=50,
186+
different_name="aliased_value",
187+
)
188+
189+
190+
class AnnotatedFieldsModel(BaseModel):
191+
max_length_str: Annotated[str, Len(max_length=10)]
192+
custom_json: Annotated[Dict[str, Any], WithJsonSchema({"extra": "data"})]
193+
194+
def _check_instance(self):
195+
assert isinstance(self.max_length_str, str)
196+
assert isinstance(self.custom_json, dict)
197+
assert len(self.max_length_str) <= 10
198+
assert self.max_length_str == "short"
199+
assert self.custom_json == {"key": "value"}
200+
201+
202+
def make_annotated_fields_object() -> AnnotatedFieldsModel:
203+
return AnnotatedFieldsModel(
204+
max_length_str="short",
205+
custom_json={"key": "value"},
206+
)
207+
208+
209+
T = TypeVar("T")
210+
211+
212+
class GenericModel(BaseModel, Generic[T]):
213+
value: T
214+
values: List[T]
215+
216+
def _check_instance(self):
217+
assert isinstance(self.value, str)
218+
assert isinstance(self.values, list)
219+
assert all(isinstance(v, str) for v in self.values)
220+
assert self.value == "single"
221+
assert self.values == ["multiple", "values"]
222+
223+
224+
def make_generic_string_object() -> GenericModel[str]:
225+
return GenericModel[str](
226+
value="single",
227+
values=["multiple", "values"],
228+
)
229+
230+
19231
class PydanticModel(BaseModel):
20232
ip_field: IPv4Address
21233
string_field_assigned_field: str = Field()
@@ -143,6 +355,13 @@ def _check_instance(self):
143355

144356

145357
PydanticModels = Union[
358+
BasicTypesModel,
359+
ComplexTypesModel,
360+
SpecialTypesModel,
361+
ParentModel,
362+
FieldFeaturesModel,
363+
AnnotatedFieldsModel,
364+
GenericModel,
146365
PydanticModel,
147366
PydanticDatetimeModel,
148367
PydanticDateModel,
@@ -166,7 +385,7 @@ def _assert_timedelta_validity(td: timedelta):
166385

167386

168387
def make_homogeneous_list_of_pydantic_objects() -> List[PydanticModel]:
169-
return [
388+
objects = [
170389
PydanticModel(
171390
ip_field=IPv4Address("127.0.0.1"),
172391
string_field_assigned_field="my-string",
@@ -175,10 +394,20 @@ def make_homogeneous_list_of_pydantic_objects() -> List[PydanticModel]:
175394
union_field="my-string",
176395
),
177396
]
178-
179-
180-
def make_heterogenous_list_of_pydantic_objects() -> List[PydanticModels]:
181-
return [
397+
for o in objects:
398+
o._check_instance()
399+
return objects
400+
401+
402+
def make_heterogeneous_list_of_pydantic_objects() -> List[PydanticModels]:
403+
objects = [
404+
make_basic_types_object(),
405+
make_complex_types_object(),
406+
make_special_types_object(),
407+
make_nested_object(),
408+
make_field_features_object(),
409+
make_annotated_fields_object(),
410+
make_generic_string_object(),
182411
PydanticModel(
183412
ip_field=IPv4Address("127.0.0.1"),
184413
string_field_assigned_field="my-string",
@@ -220,6 +449,9 @@ def make_heterogenous_list_of_pydantic_objects() -> List[PydanticModels]:
220449
],
221450
),
222451
]
452+
for o in objects:
453+
o._check_instance()
454+
return objects
223455

224456

225457
@activity.defn
@@ -237,7 +469,7 @@ async def heterogeneous_list_of_pydantic_models_activity(
237469

238470

239471
@workflow.defn
240-
class HomogenousListOfPydanticObjectsWorkflow:
472+
class HomogeneousListOfPydanticObjectsWorkflow:
241473
@workflow.run
242474
async def run(self, models: List[PydanticModel]) -> List[PydanticModel]:
243475
return await workflow.execute_activity(
@@ -248,7 +480,7 @@ async def run(self, models: List[PydanticModel]) -> List[PydanticModel]:
248480

249481

250482
@workflow.defn
251-
class HeterogenousListOfPydanticObjectsWorkflow:
483+
class HeterogeneousListOfPydanticObjectsWorkflow:
252484
@workflow.run
253485
async def run(self, models: List[PydanticModels]) -> List[PydanticModels]:
254486
return await workflow.execute_activity(
@@ -269,39 +501,43 @@ async def test_homogeneous_list_of_pydantic_objects(client: Client):
269501
async with Worker(
270502
client,
271503
task_queue=task_queue_name,
272-
workflows=[HomogenousListOfPydanticObjectsWorkflow],
504+
workflows=[HomogeneousListOfPydanticObjectsWorkflow],
273505
activities=[homogeneous_list_of_pydantic_models_activity],
274506
):
275507
round_tripped_pydantic_objects = await client.execute_workflow(
276-
HomogenousListOfPydanticObjectsWorkflow.run,
508+
HomogeneousListOfPydanticObjectsWorkflow.run,
277509
orig_pydantic_objects,
278510
id=str(uuid.uuid4()),
279511
task_queue=task_queue_name,
280512
)
281513
assert orig_pydantic_objects == round_tripped_pydantic_objects
514+
for o in round_tripped_pydantic_objects:
515+
o._check_instance()
282516

283517

284-
async def test_heterogenous_list_of_pydantic_objects(client: Client):
518+
async def test_heterogeneous_list_of_pydantic_objects(client: Client):
285519
new_config = client.config()
286520
new_config["data_converter"] = pydantic_data_converter
287521
client = Client(**new_config)
288522
task_queue_name = str(uuid.uuid4())
289523

290-
orig_pydantic_objects = make_heterogenous_list_of_pydantic_objects()
524+
orig_pydantic_objects = make_heterogeneous_list_of_pydantic_objects()
291525

292526
async with Worker(
293527
client,
294528
task_queue=task_queue_name,
295-
workflows=[HeterogenousListOfPydanticObjectsWorkflow],
529+
workflows=[HeterogeneousListOfPydanticObjectsWorkflow],
296530
activities=[heterogeneous_list_of_pydantic_models_activity],
297531
):
298532
round_tripped_pydantic_objects = await client.execute_workflow(
299-
HeterogenousListOfPydanticObjectsWorkflow.run,
533+
HeterogeneousListOfPydanticObjectsWorkflow.run,
300534
orig_pydantic_objects,
301535
id=str(uuid.uuid4()),
302536
task_queue=task_queue_name,
303537
)
304538
assert orig_pydantic_objects == round_tripped_pydantic_objects
539+
for o in round_tripped_pydantic_objects:
540+
o._check_instance()
305541

306542

307543
@dataclasses.dataclass
@@ -342,7 +578,7 @@ async def test_mixed_collection_types(client: Client):
342578
task_queue_name = str(uuid.uuid4())
343579

344580
orig_dataclass_objects = make_dataclass_objects()
345-
orig_pydantic_objects = make_heterogenous_list_of_pydantic_objects()
581+
orig_pydantic_objects = make_heterogeneous_list_of_pydantic_objects()
346582

347583
async with Worker(
348584
client,
@@ -361,13 +597,15 @@ async def test_mixed_collection_types(client: Client):
361597
)
362598
assert orig_dataclass_objects == round_tripped_dataclass_objects
363599
assert orig_pydantic_objects == round_tripped_pydantic_objects
600+
for o in round_tripped_pydantic_objects:
601+
o._check_instance()
364602

365603

366604
@workflow.defn
367605
class PydanticModelUsageWorkflow:
368606
@workflow.run
369607
async def run(self) -> None:
370-
for o in make_heterogenous_list_of_pydantic_objects():
608+
for o in make_heterogeneous_list_of_pydantic_objects():
371609
o._check_instance()
372610

373611

0 commit comments

Comments
 (0)