Skip to content

Commit d51b0dc

Browse files
committed
Test complex union
1 parent 77d01eb commit d51b0dc

File tree

1 file changed

+71
-3
lines changed

1 file changed

+71
-3
lines changed

tests/contrib/test_pydantic.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,14 @@ async def test_clone_objects_in_sandbox(client: Client):
741741
o._check_instance()
742742

743743

744-
@dataclasses.dataclass
744+
@dataclasses.dataclass(order=True)
745745
class MyDataClass:
746-
int_field: int
746+
# The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly.
747+
data_class_int_field: int
747748

748749

749750
def make_dataclass_objects() -> List[MyDataClass]:
750-
return [MyDataClass(int_field=7)]
751+
return [MyDataClass(data_class_int_field=7)]
751752

752753

753754
ComplexCustomType = Tuple[List[MyDataClass], List[PydanticModels]]
@@ -799,6 +800,73 @@ async def test_complex_custom_type(client: Client):
799800
o._check_instance()
800801

801802

803+
ComplexCustomUnionType = List[Union[MyDataClass, PydanticModels]]
804+
805+
806+
@workflow.defn
807+
class ComplexCustomUnionTypeWorkflow:
808+
@workflow.run
809+
async def run(
810+
self,
811+
input: ComplexCustomUnionType,
812+
) -> ComplexCustomUnionType:
813+
data_classes, pydantic_objects = [], []
814+
for o in input:
815+
if dataclasses.is_dataclass(o):
816+
data_classes.append(o)
817+
elif isinstance(o, BaseModel):
818+
pydantic_objects.append(o)
819+
else:
820+
raise TypeError(f"Unexpected type: {type(o)}")
821+
pydantic_objects = await workflow.execute_activity(
822+
pydantic_models_activity,
823+
pydantic_objects,
824+
start_to_close_timeout=timedelta(minutes=1),
825+
)
826+
return data_classes + pydantic_objects
827+
828+
829+
async def test_complex_custom_union_type(client: Client):
830+
new_config = client.config()
831+
new_config["data_converter"] = pydantic_data_converter
832+
client = Client(**new_config)
833+
task_queue_name = str(uuid.uuid4())
834+
835+
orig_dataclass_objects = make_dataclass_objects()
836+
orig_pydantic_objects = make_list_of_pydantic_objects()
837+
orig_objects = orig_dataclass_objects + orig_pydantic_objects
838+
import random
839+
840+
random.shuffle(orig_objects)
841+
842+
async with Worker(
843+
client,
844+
task_queue=task_queue_name,
845+
workflows=[ComplexCustomUnionTypeWorkflow],
846+
activities=[pydantic_models_activity],
847+
):
848+
round_tripped_objects = await client.execute_workflow(
849+
ComplexCustomUnionTypeWorkflow.run,
850+
orig_objects,
851+
id=str(uuid.uuid4()),
852+
task_queue=task_queue_name,
853+
)
854+
round_tripped_dataclass_objects, round_tripped_pydantic_objects = [], []
855+
for o in round_tripped_objects:
856+
if isinstance(o, MyDataClass):
857+
round_tripped_dataclass_objects.append(o)
858+
elif isinstance(o, BaseModel):
859+
round_tripped_pydantic_objects.append(o)
860+
else:
861+
raise TypeError(f"Unexpected type: {type(o)}")
862+
assert sorted(orig_dataclass_objects) == sorted(round_tripped_dataclass_objects)
863+
assert sorted(orig_pydantic_objects, key=lambda o: o.__class__.__name__) == sorted(
864+
round_tripped_pydantic_objects, key=lambda o: o.__class__.__name__
865+
)
866+
for o in round_tripped_pydantic_objects:
867+
o._check_instance()
868+
869+
802870
@workflow.defn
803871
class PydanticModelUsageWorkflow:
804872
@workflow.run

0 commit comments

Comments
 (0)