Skip to content

Commit 40914be

Browse files
committed
Reorganize
1 parent 1770a75 commit 40914be

File tree

4 files changed

+343
-307
lines changed

4 files changed

+343
-307
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import List
2+
3+
from temporalio import activity
4+
from tests.contrib.pydantic.models import PydanticModels
5+
6+
7+
@activity.defn
8+
async def pydantic_models_activity(
9+
models: List[PydanticModels],
10+
) -> List[PydanticModels]:
11+
return models
Lines changed: 16 additions & 307 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,9 @@
2626
)
2727

2828
from annotated_types import Len
29-
from pydantic import BaseModel, Field, WithJsonSchema, create_model
29+
from pydantic import BaseModel, Field, WithJsonSchema
3030
from typing_extensions import TypedDict
3131

32-
from temporalio import activity, workflow
33-
from temporalio.client import Client
34-
from temporalio.contrib.pydantic import pydantic_data_converter
35-
from temporalio.worker import Worker
36-
3732
SequenceType = TypeVar("SequenceType", bound=Sequence[Any])
3833
ShortSequence = Annotated[SequenceType, Len(max_length=2)]
3934

@@ -573,21 +568,6 @@ def make_pydantic_timedelta_object() -> PydanticTimedeltaModel:
573568
)
574569

575570

576-
PydanticModels = Union[
577-
StandardTypesModel,
578-
ComplexTypesModel,
579-
SpecialTypesModel,
580-
ParentModel,
581-
FieldFeaturesModel,
582-
AnnotatedFieldsModel,
583-
GenericModel[Any],
584-
UnionModel,
585-
PydanticDatetimeModel,
586-
PydanticDateModel,
587-
PydanticTimedeltaModel,
588-
]
589-
590-
591571
def _assert_datetime_validity(dt: datetime):
592572
assert isinstance(dt, datetime)
593573
assert issubclass(dt.__class__, datetime)
@@ -603,6 +583,21 @@ def _assert_timedelta_validity(td: timedelta):
603583
assert issubclass(td.__class__, timedelta)
604584

605585

586+
PydanticModels = Union[
587+
StandardTypesModel,
588+
ComplexTypesModel,
589+
SpecialTypesModel,
590+
ParentModel,
591+
FieldFeaturesModel,
592+
AnnotatedFieldsModel,
593+
GenericModel[Any],
594+
UnionModel,
595+
PydanticDatetimeModel,
596+
PydanticDateModel,
597+
PydanticTimedeltaModel,
598+
]
599+
600+
606601
def make_list_of_pydantic_objects() -> List[PydanticModels]:
607602
objects = [
608603
make_standard_types_object(),
@@ -622,126 +617,6 @@ def make_list_of_pydantic_objects() -> List[PydanticModels]:
622617
return objects # type: ignore
623618

624619

625-
@activity.defn
626-
async def pydantic_models_activity(
627-
models: List[PydanticModels],
628-
) -> List[PydanticModels]:
629-
return models
630-
631-
632-
@workflow.defn
633-
class InstantiateModelsWorkflow:
634-
@workflow.run
635-
async def run(self) -> None:
636-
make_list_of_pydantic_objects()
637-
638-
639-
@workflow.defn
640-
class RoundTripObjectsWorkflow:
641-
@workflow.run
642-
async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]:
643-
return await workflow.execute_activity(
644-
pydantic_models_activity,
645-
objects,
646-
start_to_close_timeout=timedelta(minutes=1),
647-
)
648-
649-
650-
def clone_objects(objects: List[PydanticModels]) -> List[PydanticModels]:
651-
new_objects = []
652-
for o in objects:
653-
fields = {}
654-
for name, f in o.model_fields.items():
655-
fields[name] = (f.annotation, f)
656-
model = create_model(o.__class__.__name__, **fields) # type: ignore
657-
new_objects.append(model(**o.model_dump(by_alias=True)))
658-
for old, new in zip(objects, new_objects):
659-
assert old.model_dump() == new.model_dump()
660-
return new_objects
661-
662-
663-
@workflow.defn
664-
class CloneObjectsWorkflow:
665-
@workflow.run
666-
async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]:
667-
return clone_objects(objects)
668-
669-
670-
async def test_instantiation_outside_sandbox():
671-
make_list_of_pydantic_objects()
672-
673-
674-
async def test_instantiation_inside_sandbox(client: Client):
675-
new_config = client.config()
676-
new_config["data_converter"] = pydantic_data_converter
677-
client = Client(**new_config)
678-
task_queue_name = str(uuid.uuid4())
679-
680-
async with Worker(
681-
client,
682-
task_queue=task_queue_name,
683-
workflows=[InstantiateModelsWorkflow],
684-
):
685-
await client.execute_workflow(
686-
InstantiateModelsWorkflow.run,
687-
id=str(uuid.uuid4()),
688-
task_queue=task_queue_name,
689-
)
690-
691-
692-
async def test_round_trip_pydantic_objects(client: Client):
693-
new_config = client.config()
694-
new_config["data_converter"] = pydantic_data_converter
695-
client = Client(**new_config)
696-
task_queue_name = str(uuid.uuid4())
697-
698-
orig_objects = make_list_of_pydantic_objects()
699-
700-
async with Worker(
701-
client,
702-
task_queue=task_queue_name,
703-
workflows=[RoundTripObjectsWorkflow],
704-
activities=[pydantic_models_activity],
705-
):
706-
returned_objects = await client.execute_workflow(
707-
RoundTripObjectsWorkflow.run,
708-
orig_objects,
709-
id=str(uuid.uuid4()),
710-
task_queue=task_queue_name,
711-
)
712-
assert returned_objects == orig_objects
713-
for o in returned_objects:
714-
o._check_instance()
715-
716-
717-
async def test_clone_objects_outside_sandbox():
718-
clone_objects(make_list_of_pydantic_objects())
719-
720-
721-
async def test_clone_objects_in_sandbox(client: Client):
722-
new_config = client.config()
723-
new_config["data_converter"] = pydantic_data_converter
724-
client = Client(**new_config)
725-
task_queue_name = str(uuid.uuid4())
726-
727-
orig_objects = make_list_of_pydantic_objects()
728-
729-
async with Worker(
730-
client,
731-
task_queue=task_queue_name,
732-
workflows=[CloneObjectsWorkflow],
733-
):
734-
returned_objects = await client.execute_workflow(
735-
CloneObjectsWorkflow.run,
736-
orig_objects,
737-
id=str(uuid.uuid4()),
738-
task_queue=task_queue_name,
739-
)
740-
assert returned_objects == orig_objects
741-
for o in returned_objects:
742-
o._check_instance()
743-
744-
745620
@dataclasses.dataclass(order=True)
746621
class MyDataClass:
747622
# The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly.
@@ -753,170 +628,4 @@ def make_dataclass_objects() -> List[MyDataClass]:
753628

754629

755630
ComplexCustomType = Tuple[List[MyDataClass], List[PydanticModels]]
756-
757-
758-
@workflow.defn
759-
class ComplexCustomTypeWorkflow:
760-
@workflow.run
761-
async def run(
762-
self,
763-
input: ComplexCustomType,
764-
) -> ComplexCustomType:
765-
data_classes, pydantic_objects = input
766-
pydantic_objects = await workflow.execute_activity(
767-
pydantic_models_activity,
768-
pydantic_objects,
769-
start_to_close_timeout=timedelta(minutes=1),
770-
)
771-
return data_classes, pydantic_objects
772-
773-
774-
async def test_complex_custom_type(client: Client):
775-
new_config = client.config()
776-
new_config["data_converter"] = pydantic_data_converter
777-
client = Client(**new_config)
778-
task_queue_name = str(uuid.uuid4())
779-
780-
orig_dataclass_objects = make_dataclass_objects()
781-
orig_pydantic_objects = make_list_of_pydantic_objects()
782-
783-
async with Worker(
784-
client,
785-
task_queue=task_queue_name,
786-
workflows=[ComplexCustomTypeWorkflow],
787-
activities=[pydantic_models_activity],
788-
):
789-
(
790-
returned_dataclass_objects,
791-
returned_pydantic_objects,
792-
) = await client.execute_workflow(
793-
ComplexCustomTypeWorkflow.run,
794-
(orig_dataclass_objects, orig_pydantic_objects),
795-
id=str(uuid.uuid4()),
796-
task_queue=task_queue_name,
797-
)
798-
assert orig_dataclass_objects == returned_dataclass_objects
799-
assert orig_pydantic_objects == returned_pydantic_objects
800-
for o in returned_pydantic_objects:
801-
o._check_instance()
802-
803-
804631
ComplexCustomUnionType = List[Union[MyDataClass, PydanticModels]]
805-
806-
807-
@workflow.defn
808-
class ComplexCustomUnionTypeWorkflow:
809-
@workflow.run
810-
async def run(
811-
self,
812-
input: ComplexCustomUnionType,
813-
) -> ComplexCustomUnionType:
814-
data_classes = []
815-
pydantic_objects: List[PydanticModels] = []
816-
for o in input:
817-
if dataclasses.is_dataclass(o):
818-
data_classes.append(o)
819-
elif isinstance(o, BaseModel):
820-
pydantic_objects.append(o)
821-
else:
822-
raise TypeError(f"Unexpected type: {type(o)}")
823-
pydantic_objects = await workflow.execute_activity(
824-
pydantic_models_activity,
825-
pydantic_objects,
826-
start_to_close_timeout=timedelta(minutes=1),
827-
)
828-
return data_classes + pydantic_objects # type: ignore
829-
830-
831-
async def test_complex_custom_union_type(client: Client):
832-
new_config = client.config()
833-
new_config["data_converter"] = pydantic_data_converter
834-
client = Client(**new_config)
835-
task_queue_name = str(uuid.uuid4())
836-
837-
orig_dataclass_objects = make_dataclass_objects()
838-
orig_pydantic_objects = make_list_of_pydantic_objects()
839-
orig_objects = orig_dataclass_objects + orig_pydantic_objects
840-
import random
841-
842-
random.shuffle(orig_objects)
843-
844-
async with Worker(
845-
client,
846-
task_queue=task_queue_name,
847-
workflows=[ComplexCustomUnionTypeWorkflow],
848-
activities=[pydantic_models_activity],
849-
):
850-
returned_objects = await client.execute_workflow(
851-
ComplexCustomUnionTypeWorkflow.run,
852-
orig_objects,
853-
id=str(uuid.uuid4()),
854-
task_queue=task_queue_name,
855-
)
856-
returned_dataclass_objects, returned_pydantic_objects = [], []
857-
for o in returned_objects:
858-
if isinstance(o, MyDataClass):
859-
returned_dataclass_objects.append(o)
860-
elif isinstance(o, BaseModel):
861-
returned_pydantic_objects.append(o)
862-
else:
863-
raise TypeError(f"Unexpected type: {type(o)}")
864-
assert sorted(orig_dataclass_objects) == sorted(returned_dataclass_objects)
865-
assert sorted(orig_pydantic_objects, key=lambda o: o.__class__.__name__) == sorted(
866-
returned_pydantic_objects, key=lambda o: o.__class__.__name__
867-
)
868-
for o in returned_pydantic_objects:
869-
o._check_instance()
870-
871-
872-
@workflow.defn
873-
class PydanticModelUsageWorkflow:
874-
@workflow.run
875-
async def run(self) -> None:
876-
for o in make_list_of_pydantic_objects():
877-
o._check_instance()
878-
879-
880-
async def test_pydantic_model_usage_in_workflow(client: Client):
881-
new_config = client.config()
882-
new_config["data_converter"] = pydantic_data_converter
883-
client = Client(**new_config)
884-
task_queue_name = str(uuid.uuid4())
885-
886-
async with Worker(
887-
client,
888-
task_queue=task_queue_name,
889-
workflows=[PydanticModelUsageWorkflow],
890-
):
891-
await client.execute_workflow(
892-
PydanticModelUsageWorkflow.run,
893-
id=str(uuid.uuid4()),
894-
task_queue=task_queue_name,
895-
)
896-
897-
898-
@workflow.defn
899-
class DatetimeUsageWorkflow:
900-
@workflow.run
901-
async def run(self) -> None:
902-
dt = workflow.now()
903-
assert isinstance(dt, datetime)
904-
assert issubclass(dt.__class__, datetime)
905-
906-
907-
async def test_datetime_usage_in_workflow(client: Client):
908-
new_config = client.config()
909-
new_config["data_converter"] = pydantic_data_converter
910-
client = Client(**new_config)
911-
task_queue_name = str(uuid.uuid4())
912-
913-
async with Worker(
914-
client,
915-
task_queue=task_queue_name,
916-
workflows=[DatetimeUsageWorkflow],
917-
):
918-
await client.execute_workflow(
919-
DatetimeUsageWorkflow.run,
920-
id=str(uuid.uuid4()),
921-
task_queue=task_queue_name,
922-
)

0 commit comments

Comments
 (0)