22import uuid
33from datetime import date , datetime , timedelta
44from ipaddress import IPv4Address
5- from typing import Annotated , Any , List , Sequence , Tuple , TypeVar , Union , get_type_hints
5+ from typing import (
6+ Annotated ,
7+ Any ,
8+ Dict ,
9+ Generic ,
10+ List ,
11+ Optional ,
12+ Sequence ,
13+ Set ,
14+ Tuple ,
15+ TypeVar ,
16+ Union ,
17+ get_type_hints ,
18+ )
619
720from annotated_types import Len
821from pydantic import BaseModel , Field , WithJsonSchema
1629ShortSequence = Annotated [SequenceType , Len (max_length = 2 )]
1730
1831
32+ class BasicTypesModel (BaseModel ):
33+ int_field : int
34+ float_field : float
35+ str_field : str
36+ bool_field : bool
37+ bytes_field : bytes
38+ none_field : None
39+
40+ def _check_instance (self ):
41+ assert isinstance (self .int_field , int )
42+ assert isinstance (self .float_field , float )
43+ assert isinstance (self .str_field , str )
44+ assert isinstance (self .bool_field , bool )
45+ assert isinstance (self .bytes_field , bytes )
46+ assert self .none_field is None
47+ assert self .int_field == 42
48+ assert self .float_field == 3.14
49+ assert self .str_field == "hello"
50+ assert self .bool_field is True
51+ assert self .bytes_field == b"world"
52+
53+
54+ def make_basic_types_object () -> BasicTypesModel :
55+ return BasicTypesModel (
56+ int_field = 42 ,
57+ float_field = 3.14 ,
58+ str_field = "hello" ,
59+ bool_field = True ,
60+ bytes_field = b"world" ,
61+ none_field = None ,
62+ )
63+
64+
65+ class ComplexTypesModel (BaseModel ):
66+ list_field : List [str ]
67+ dict_field : Dict [str , int ]
68+ set_field : Set [int ]
69+ tuple_field : Tuple [str , int ]
70+ union_field : Union [str , int ]
71+ optional_field : Optional [str ]
72+
73+ def _check_instance (self ):
74+ assert isinstance (self .list_field , list )
75+ assert isinstance (self .dict_field , dict )
76+ assert isinstance (self .set_field , set )
77+ assert isinstance (self .tuple_field , tuple )
78+ assert isinstance (self .union_field , str )
79+ assert isinstance (self .optional_field , str )
80+ assert self .list_field == ["a" , "b" , "c" ]
81+ assert self .dict_field == {"x" : 1 , "y" : 2 }
82+ assert self .set_field == {1 , 2 , 3 }
83+ assert self .tuple_field == ("hello" , 42 )
84+ assert self .union_field == "string_or_int"
85+ assert self .optional_field == "present"
86+
87+
88+ def make_complex_types_object () -> ComplexTypesModel :
89+ return ComplexTypesModel (
90+ list_field = ["a" , "b" , "c" ],
91+ dict_field = {"x" : 1 , "y" : 2 },
92+ set_field = {1 , 2 , 3 },
93+ tuple_field = ("hello" , 42 ),
94+ union_field = "string_or_int" ,
95+ optional_field = "present" ,
96+ )
97+
98+
99+ class SpecialTypesModel (BaseModel ):
100+ datetime_field : datetime
101+ date_field : date
102+ timedelta_field : timedelta
103+ # path_field: Path
104+ uuid_field : uuid .UUID
105+ ip_field : IPv4Address
106+
107+ def _check_instance (self ):
108+ assert isinstance (self .datetime_field , datetime )
109+ assert isinstance (self .date_field , date )
110+ assert isinstance (self .timedelta_field , timedelta )
111+ # assert isinstance(self.path_field, Path)
112+ assert isinstance (self .uuid_field , uuid .UUID )
113+ assert isinstance (self .ip_field , IPv4Address )
114+ assert self .datetime_field == datetime (2000 , 1 , 2 , 3 , 4 , 5 )
115+ assert self .date_field == date (2000 , 1 , 2 )
116+ assert self .timedelta_field == timedelta (days = 1 , hours = 2 )
117+ # assert self.path_field == Path("test/path")
118+ assert self .uuid_field == uuid .UUID ("12345678-1234-5678-1234-567812345678" )
119+ assert self .ip_field == IPv4Address ("127.0.0.1" )
120+
121+
122+ def make_special_types_object () -> SpecialTypesModel :
123+ return SpecialTypesModel (
124+ datetime_field = datetime (2000 , 1 , 2 , 3 , 4 , 5 ),
125+ date_field = date (2000 , 1 , 2 ),
126+ timedelta_field = timedelta (days = 1 , hours = 2 ),
127+ # path_field=Path("test/path"),
128+ uuid_field = uuid .UUID ("12345678-1234-5678-1234-567812345678" ),
129+ ip_field = IPv4Address ("127.0.0.1" ),
130+ )
131+
132+
133+ class ChildModel (BaseModel ):
134+ name : str
135+ value : int
136+
137+
138+ class ParentModel (BaseModel ):
139+ child : ChildModel
140+ children : List [ChildModel ]
141+
142+ def _check_instance (self ):
143+ assert isinstance (self .child , ChildModel )
144+ assert isinstance (self .children , list )
145+ assert all (isinstance (child , ChildModel ) for child in self .children )
146+ assert self .child .name == "child1"
147+ assert self .child .value == 1
148+ assert len (self .children ) == 2
149+ assert self .children [0 ].name == "child2"
150+ assert self .children [0 ].value == 2
151+ assert self .children [1 ].name == "child3"
152+ assert self .children [1 ].value == 3
153+
154+
155+ def make_nested_object () -> ParentModel :
156+ return ParentModel (
157+ child = ChildModel (name = "child1" , value = 1 ),
158+ children = [
159+ ChildModel (name = "child2" , value = 2 ),
160+ ChildModel (name = "child3" , value = 3 ),
161+ ],
162+ )
163+
164+
165+ class FieldFeaturesModel (BaseModel ):
166+ field_with_default : str = "default"
167+ field_with_factory : datetime = Field (
168+ default_factory = lambda : datetime (2000 , 1 , 2 , 3 , 4 , 5 )
169+ )
170+ field_with_constraints : int = Field (gt = 0 , lt = 100 )
171+ field_with_alias : str = Field (alias = "different_name" )
172+
173+ def _check_instance (self ):
174+ assert isinstance (self .field_with_default , str )
175+ assert isinstance (self .field_with_factory , datetime )
176+ assert isinstance (self .field_with_constraints , int )
177+ assert isinstance (self .field_with_alias , str )
178+ assert self .field_with_default == "default"
179+ assert 0 < self .field_with_constraints < 100
180+ assert self .field_with_alias == "aliased_value"
181+
182+
183+ def make_field_features_object () -> FieldFeaturesModel :
184+ return FieldFeaturesModel (
185+ field_with_constraints = 50 ,
186+ different_name = "aliased_value" ,
187+ )
188+
189+
190+ class AnnotatedFieldsModel (BaseModel ):
191+ max_length_str : Annotated [str , Len (max_length = 10 )]
192+ custom_json : Annotated [Dict [str , Any ], WithJsonSchema ({"extra" : "data" })]
193+
194+ def _check_instance (self ):
195+ assert isinstance (self .max_length_str , str )
196+ assert isinstance (self .custom_json , dict )
197+ assert len (self .max_length_str ) <= 10
198+ assert self .max_length_str == "short"
199+ assert self .custom_json == {"key" : "value" }
200+
201+
202+ def make_annotated_fields_object () -> AnnotatedFieldsModel :
203+ return AnnotatedFieldsModel (
204+ max_length_str = "short" ,
205+ custom_json = {"key" : "value" },
206+ )
207+
208+
209+ T = TypeVar ("T" )
210+
211+
212+ class GenericModel (BaseModel , Generic [T ]):
213+ value : T
214+ values : List [T ]
215+
216+ def _check_instance (self ):
217+ assert isinstance (self .value , str )
218+ assert isinstance (self .values , list )
219+ assert all (isinstance (v , str ) for v in self .values )
220+ assert self .value == "single"
221+ assert self .values == ["multiple" , "values" ]
222+
223+
224+ def make_generic_string_object () -> GenericModel [str ]:
225+ return GenericModel [str ](
226+ value = "single" ,
227+ values = ["multiple" , "values" ],
228+ )
229+
230+
19231class PydanticModel (BaseModel ):
20232 ip_field : IPv4Address
21233 string_field_assigned_field : str = Field ()
@@ -143,6 +355,13 @@ def _check_instance(self):
143355
144356
145357PydanticModels = Union [
358+ BasicTypesModel ,
359+ ComplexTypesModel ,
360+ SpecialTypesModel ,
361+ ParentModel ,
362+ FieldFeaturesModel ,
363+ AnnotatedFieldsModel ,
364+ GenericModel ,
146365 PydanticModel ,
147366 PydanticDatetimeModel ,
148367 PydanticDateModel ,
@@ -166,7 +385,7 @@ def _assert_timedelta_validity(td: timedelta):
166385
167386
168387def make_homogeneous_list_of_pydantic_objects () -> List [PydanticModel ]:
169- return [
388+ objects = [
170389 PydanticModel (
171390 ip_field = IPv4Address ("127.0.0.1" ),
172391 string_field_assigned_field = "my-string" ,
@@ -175,10 +394,20 @@ def make_homogeneous_list_of_pydantic_objects() -> List[PydanticModel]:
175394 union_field = "my-string" ,
176395 ),
177396 ]
178-
179-
180- def make_heterogenous_list_of_pydantic_objects () -> List [PydanticModels ]:
181- return [
397+ for o in objects :
398+ o ._check_instance ()
399+ return objects
400+
401+
402+ def make_heterogeneous_list_of_pydantic_objects () -> List [PydanticModels ]:
403+ objects = [
404+ make_basic_types_object (),
405+ make_complex_types_object (),
406+ make_special_types_object (),
407+ make_nested_object (),
408+ make_field_features_object (),
409+ make_annotated_fields_object (),
410+ make_generic_string_object (),
182411 PydanticModel (
183412 ip_field = IPv4Address ("127.0.0.1" ),
184413 string_field_assigned_field = "my-string" ,
@@ -220,6 +449,9 @@ def make_heterogenous_list_of_pydantic_objects() -> List[PydanticModels]:
220449 ],
221450 ),
222451 ]
452+ for o in objects :
453+ o ._check_instance ()
454+ return objects
223455
224456
225457@activity .defn
@@ -237,7 +469,7 @@ async def heterogeneous_list_of_pydantic_models_activity(
237469
238470
239471@workflow .defn
240- class HomogenousListOfPydanticObjectsWorkflow :
472+ class HomogeneousListOfPydanticObjectsWorkflow :
241473 @workflow .run
242474 async def run (self , models : List [PydanticModel ]) -> List [PydanticModel ]:
243475 return await workflow .execute_activity (
@@ -248,7 +480,7 @@ async def run(self, models: List[PydanticModel]) -> List[PydanticModel]:
248480
249481
250482@workflow .defn
251- class HeterogenousListOfPydanticObjectsWorkflow :
483+ class HeterogeneousListOfPydanticObjectsWorkflow :
252484 @workflow .run
253485 async def run (self , models : List [PydanticModels ]) -> List [PydanticModels ]:
254486 return await workflow .execute_activity (
@@ -269,39 +501,43 @@ async def test_homogeneous_list_of_pydantic_objects(client: Client):
269501 async with Worker (
270502 client ,
271503 task_queue = task_queue_name ,
272- workflows = [HomogenousListOfPydanticObjectsWorkflow ],
504+ workflows = [HomogeneousListOfPydanticObjectsWorkflow ],
273505 activities = [homogeneous_list_of_pydantic_models_activity ],
274506 ):
275507 round_tripped_pydantic_objects = await client .execute_workflow (
276- HomogenousListOfPydanticObjectsWorkflow .run ,
508+ HomogeneousListOfPydanticObjectsWorkflow .run ,
277509 orig_pydantic_objects ,
278510 id = str (uuid .uuid4 ()),
279511 task_queue = task_queue_name ,
280512 )
281513 assert orig_pydantic_objects == round_tripped_pydantic_objects
514+ for o in round_tripped_pydantic_objects :
515+ o ._check_instance ()
282516
283517
284- async def test_heterogenous_list_of_pydantic_objects (client : Client ):
518+ async def test_heterogeneous_list_of_pydantic_objects (client : Client ):
285519 new_config = client .config ()
286520 new_config ["data_converter" ] = pydantic_data_converter
287521 client = Client (** new_config )
288522 task_queue_name = str (uuid .uuid4 ())
289523
290- orig_pydantic_objects = make_heterogenous_list_of_pydantic_objects ()
524+ orig_pydantic_objects = make_heterogeneous_list_of_pydantic_objects ()
291525
292526 async with Worker (
293527 client ,
294528 task_queue = task_queue_name ,
295- workflows = [HeterogenousListOfPydanticObjectsWorkflow ],
529+ workflows = [HeterogeneousListOfPydanticObjectsWorkflow ],
296530 activities = [heterogeneous_list_of_pydantic_models_activity ],
297531 ):
298532 round_tripped_pydantic_objects = await client .execute_workflow (
299- HeterogenousListOfPydanticObjectsWorkflow .run ,
533+ HeterogeneousListOfPydanticObjectsWorkflow .run ,
300534 orig_pydantic_objects ,
301535 id = str (uuid .uuid4 ()),
302536 task_queue = task_queue_name ,
303537 )
304538 assert orig_pydantic_objects == round_tripped_pydantic_objects
539+ for o in round_tripped_pydantic_objects :
540+ o ._check_instance ()
305541
306542
307543@dataclasses .dataclass
@@ -342,7 +578,7 @@ async def test_mixed_collection_types(client: Client):
342578 task_queue_name = str (uuid .uuid4 ())
343579
344580 orig_dataclass_objects = make_dataclass_objects ()
345- orig_pydantic_objects = make_heterogenous_list_of_pydantic_objects ()
581+ orig_pydantic_objects = make_heterogeneous_list_of_pydantic_objects ()
346582
347583 async with Worker (
348584 client ,
@@ -361,13 +597,15 @@ async def test_mixed_collection_types(client: Client):
361597 )
362598 assert orig_dataclass_objects == round_tripped_dataclass_objects
363599 assert orig_pydantic_objects == round_tripped_pydantic_objects
600+ for o in round_tripped_pydantic_objects :
601+ o ._check_instance ()
364602
365603
366604@workflow .defn
367605class PydanticModelUsageWorkflow :
368606 @workflow .run
369607 async def run (self ) -> None :
370- for o in make_heterogenous_list_of_pydantic_objects ():
608+ for o in make_heterogeneous_list_of_pydantic_objects ():
371609 o ._check_instance ()
372610
373611
0 commit comments