@@ -741,13 +741,14 @@ async def test_clone_objects_in_sandbox(client: Client):
741741 o ._check_instance ()
742742
743743
744- @dataclasses .dataclass
744+ @dataclasses .dataclass ( order = True )
745745class MyDataClass :
746- int_field : int
746+ # The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly.
747+ data_class_int_field : int
747748
748749
749750def make_dataclass_objects () -> List [MyDataClass ]:
750- return [MyDataClass (int_field = 7 )]
751+ return [MyDataClass (data_class_int_field = 7 )]
751752
752753
753754ComplexCustomType = Tuple [List [MyDataClass ], List [PydanticModels ]]
@@ -799,6 +800,73 @@ async def test_complex_custom_type(client: Client):
799800 o ._check_instance ()
800801
801802
803+ ComplexCustomUnionType = List [Union [MyDataClass , PydanticModels ]]
804+
805+
806+ @workflow .defn
807+ class ComplexCustomUnionTypeWorkflow :
808+ @workflow .run
809+ async def run (
810+ self ,
811+ input : ComplexCustomUnionType ,
812+ ) -> ComplexCustomUnionType :
813+ data_classes , pydantic_objects = [], []
814+ for o in input :
815+ if dataclasses .is_dataclass (o ):
816+ data_classes .append (o )
817+ elif isinstance (o , BaseModel ):
818+ pydantic_objects .append (o )
819+ else :
820+ raise TypeError (f"Unexpected type: { type (o )} " )
821+ pydantic_objects = await workflow .execute_activity (
822+ pydantic_models_activity ,
823+ pydantic_objects ,
824+ start_to_close_timeout = timedelta (minutes = 1 ),
825+ )
826+ return data_classes + pydantic_objects
827+
828+
829+ async def test_complex_custom_union_type (client : Client ):
830+ new_config = client .config ()
831+ new_config ["data_converter" ] = pydantic_data_converter
832+ client = Client (** new_config )
833+ task_queue_name = str (uuid .uuid4 ())
834+
835+ orig_dataclass_objects = make_dataclass_objects ()
836+ orig_pydantic_objects = make_list_of_pydantic_objects ()
837+ orig_objects = orig_dataclass_objects + orig_pydantic_objects
838+ import random
839+
840+ random .shuffle (orig_objects )
841+
842+ async with Worker (
843+ client ,
844+ task_queue = task_queue_name ,
845+ workflows = [ComplexCustomUnionTypeWorkflow ],
846+ activities = [pydantic_models_activity ],
847+ ):
848+ round_tripped_objects = await client .execute_workflow (
849+ ComplexCustomUnionTypeWorkflow .run ,
850+ orig_objects ,
851+ id = str (uuid .uuid4 ()),
852+ task_queue = task_queue_name ,
853+ )
854+ round_tripped_dataclass_objects , round_tripped_pydantic_objects = [], []
855+ for o in round_tripped_objects :
856+ if isinstance (o , MyDataClass ):
857+ round_tripped_dataclass_objects .append (o )
858+ elif isinstance (o , BaseModel ):
859+ round_tripped_pydantic_objects .append (o )
860+ else :
861+ raise TypeError (f"Unexpected type: { type (o )} " )
862+ assert sorted (orig_dataclass_objects ) == sorted (round_tripped_dataclass_objects )
863+ assert sorted (orig_pydantic_objects , key = lambda o : o .__class__ .__name__ ) == sorted (
864+ round_tripped_pydantic_objects , key = lambda o : o .__class__ .__name__
865+ )
866+ for o in round_tripped_pydantic_objects :
867+ o ._check_instance ()
868+
869+
802870@workflow .defn
803871class PydanticModelUsageWorkflow :
804872 @workflow .run
0 commit comments