2626)
2727
2828from annotated_types import Len
29- from pydantic import BaseModel , Field , WithJsonSchema , create_model
29+ from pydantic import BaseModel , Field , WithJsonSchema
3030from typing_extensions import TypedDict
3131
32- from temporalio import activity , workflow
33- from temporalio .client import Client
34- from temporalio .contrib .pydantic import pydantic_data_converter
35- from temporalio .worker import Worker
36-
3732SequenceType = TypeVar ("SequenceType" , bound = Sequence [Any ])
3833ShortSequence = Annotated [SequenceType , Len (max_length = 2 )]
3934
@@ -573,21 +568,6 @@ def make_pydantic_timedelta_object() -> PydanticTimedeltaModel:
573568 )
574569
575570
576- PydanticModels = Union [
577- StandardTypesModel ,
578- ComplexTypesModel ,
579- SpecialTypesModel ,
580- ParentModel ,
581- FieldFeaturesModel ,
582- AnnotatedFieldsModel ,
583- GenericModel [Any ],
584- UnionModel ,
585- PydanticDatetimeModel ,
586- PydanticDateModel ,
587- PydanticTimedeltaModel ,
588- ]
589-
590-
591571def _assert_datetime_validity (dt : datetime ):
592572 assert isinstance (dt , datetime )
593573 assert issubclass (dt .__class__ , datetime )
@@ -603,6 +583,21 @@ def _assert_timedelta_validity(td: timedelta):
603583 assert issubclass (td .__class__ , timedelta )
604584
605585
586+ PydanticModels = Union [
587+ StandardTypesModel ,
588+ ComplexTypesModel ,
589+ SpecialTypesModel ,
590+ ParentModel ,
591+ FieldFeaturesModel ,
592+ AnnotatedFieldsModel ,
593+ GenericModel [Any ],
594+ UnionModel ,
595+ PydanticDatetimeModel ,
596+ PydanticDateModel ,
597+ PydanticTimedeltaModel ,
598+ ]
599+
600+
606601def make_list_of_pydantic_objects () -> List [PydanticModels ]:
607602 objects = [
608603 make_standard_types_object (),
@@ -622,126 +617,6 @@ def make_list_of_pydantic_objects() -> List[PydanticModels]:
622617 return objects # type: ignore
623618
624619
625- @activity .defn
626- async def pydantic_models_activity (
627- models : List [PydanticModels ],
628- ) -> List [PydanticModels ]:
629- return models
630-
631-
632- @workflow .defn
633- class InstantiateModelsWorkflow :
634- @workflow .run
635- async def run (self ) -> None :
636- make_list_of_pydantic_objects ()
637-
638-
639- @workflow .defn
640- class RoundTripObjectsWorkflow :
641- @workflow .run
642- async def run (self , objects : List [PydanticModels ]) -> List [PydanticModels ]:
643- return await workflow .execute_activity (
644- pydantic_models_activity ,
645- objects ,
646- start_to_close_timeout = timedelta (minutes = 1 ),
647- )
648-
649-
650- def clone_objects (objects : List [PydanticModels ]) -> List [PydanticModels ]:
651- new_objects = []
652- for o in objects :
653- fields = {}
654- for name , f in o .model_fields .items ():
655- fields [name ] = (f .annotation , f )
656- model = create_model (o .__class__ .__name__ , ** fields ) # type: ignore
657- new_objects .append (model (** o .model_dump (by_alias = True )))
658- for old , new in zip (objects , new_objects ):
659- assert old .model_dump () == new .model_dump ()
660- return new_objects
661-
662-
663- @workflow .defn
664- class CloneObjectsWorkflow :
665- @workflow .run
666- async def run (self , objects : List [PydanticModels ]) -> List [PydanticModels ]:
667- return clone_objects (objects )
668-
669-
670- async def test_instantiation_outside_sandbox ():
671- make_list_of_pydantic_objects ()
672-
673-
674- async def test_instantiation_inside_sandbox (client : Client ):
675- new_config = client .config ()
676- new_config ["data_converter" ] = pydantic_data_converter
677- client = Client (** new_config )
678- task_queue_name = str (uuid .uuid4 ())
679-
680- async with Worker (
681- client ,
682- task_queue = task_queue_name ,
683- workflows = [InstantiateModelsWorkflow ],
684- ):
685- await client .execute_workflow (
686- InstantiateModelsWorkflow .run ,
687- id = str (uuid .uuid4 ()),
688- task_queue = task_queue_name ,
689- )
690-
691-
692- async def test_round_trip_pydantic_objects (client : Client ):
693- new_config = client .config ()
694- new_config ["data_converter" ] = pydantic_data_converter
695- client = Client (** new_config )
696- task_queue_name = str (uuid .uuid4 ())
697-
698- orig_objects = make_list_of_pydantic_objects ()
699-
700- async with Worker (
701- client ,
702- task_queue = task_queue_name ,
703- workflows = [RoundTripObjectsWorkflow ],
704- activities = [pydantic_models_activity ],
705- ):
706- returned_objects = await client .execute_workflow (
707- RoundTripObjectsWorkflow .run ,
708- orig_objects ,
709- id = str (uuid .uuid4 ()),
710- task_queue = task_queue_name ,
711- )
712- assert returned_objects == orig_objects
713- for o in returned_objects :
714- o ._check_instance ()
715-
716-
717- async def test_clone_objects_outside_sandbox ():
718- clone_objects (make_list_of_pydantic_objects ())
719-
720-
721- async def test_clone_objects_in_sandbox (client : Client ):
722- new_config = client .config ()
723- new_config ["data_converter" ] = pydantic_data_converter
724- client = Client (** new_config )
725- task_queue_name = str (uuid .uuid4 ())
726-
727- orig_objects = make_list_of_pydantic_objects ()
728-
729- async with Worker (
730- client ,
731- task_queue = task_queue_name ,
732- workflows = [CloneObjectsWorkflow ],
733- ):
734- returned_objects = await client .execute_workflow (
735- CloneObjectsWorkflow .run ,
736- orig_objects ,
737- id = str (uuid .uuid4 ()),
738- task_queue = task_queue_name ,
739- )
740- assert returned_objects == orig_objects
741- for o in returned_objects :
742- o ._check_instance ()
743-
744-
745620@dataclasses .dataclass (order = True )
746621class MyDataClass :
747622 # The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly.
@@ -753,170 +628,4 @@ def make_dataclass_objects() -> List[MyDataClass]:
753628
754629
755630ComplexCustomType = Tuple [List [MyDataClass ], List [PydanticModels ]]
756-
757-
758- @workflow .defn
759- class ComplexCustomTypeWorkflow :
760- @workflow .run
761- async def run (
762- self ,
763- input : ComplexCustomType ,
764- ) -> ComplexCustomType :
765- data_classes , pydantic_objects = input
766- pydantic_objects = await workflow .execute_activity (
767- pydantic_models_activity ,
768- pydantic_objects ,
769- start_to_close_timeout = timedelta (minutes = 1 ),
770- )
771- return data_classes , pydantic_objects
772-
773-
774- async def test_complex_custom_type (client : Client ):
775- new_config = client .config ()
776- new_config ["data_converter" ] = pydantic_data_converter
777- client = Client (** new_config )
778- task_queue_name = str (uuid .uuid4 ())
779-
780- orig_dataclass_objects = make_dataclass_objects ()
781- orig_pydantic_objects = make_list_of_pydantic_objects ()
782-
783- async with Worker (
784- client ,
785- task_queue = task_queue_name ,
786- workflows = [ComplexCustomTypeWorkflow ],
787- activities = [pydantic_models_activity ],
788- ):
789- (
790- returned_dataclass_objects ,
791- returned_pydantic_objects ,
792- ) = await client .execute_workflow (
793- ComplexCustomTypeWorkflow .run ,
794- (orig_dataclass_objects , orig_pydantic_objects ),
795- id = str (uuid .uuid4 ()),
796- task_queue = task_queue_name ,
797- )
798- assert orig_dataclass_objects == returned_dataclass_objects
799- assert orig_pydantic_objects == returned_pydantic_objects
800- for o in returned_pydantic_objects :
801- o ._check_instance ()
802-
803-
804631ComplexCustomUnionType = List [Union [MyDataClass , PydanticModels ]]
805-
806-
807- @workflow .defn
808- class ComplexCustomUnionTypeWorkflow :
809- @workflow .run
810- async def run (
811- self ,
812- input : ComplexCustomUnionType ,
813- ) -> ComplexCustomUnionType :
814- data_classes = []
815- pydantic_objects : List [PydanticModels ] = []
816- for o in input :
817- if dataclasses .is_dataclass (o ):
818- data_classes .append (o )
819- elif isinstance (o , BaseModel ):
820- pydantic_objects .append (o )
821- else :
822- raise TypeError (f"Unexpected type: { type (o )} " )
823- pydantic_objects = await workflow .execute_activity (
824- pydantic_models_activity ,
825- pydantic_objects ,
826- start_to_close_timeout = timedelta (minutes = 1 ),
827- )
828- return data_classes + pydantic_objects # type: ignore
829-
830-
831- async def test_complex_custom_union_type (client : Client ):
832- new_config = client .config ()
833- new_config ["data_converter" ] = pydantic_data_converter
834- client = Client (** new_config )
835- task_queue_name = str (uuid .uuid4 ())
836-
837- orig_dataclass_objects = make_dataclass_objects ()
838- orig_pydantic_objects = make_list_of_pydantic_objects ()
839- orig_objects = orig_dataclass_objects + orig_pydantic_objects
840- import random
841-
842- random .shuffle (orig_objects )
843-
844- async with Worker (
845- client ,
846- task_queue = task_queue_name ,
847- workflows = [ComplexCustomUnionTypeWorkflow ],
848- activities = [pydantic_models_activity ],
849- ):
850- returned_objects = await client .execute_workflow (
851- ComplexCustomUnionTypeWorkflow .run ,
852- orig_objects ,
853- id = str (uuid .uuid4 ()),
854- task_queue = task_queue_name ,
855- )
856- returned_dataclass_objects , returned_pydantic_objects = [], []
857- for o in returned_objects :
858- if isinstance (o , MyDataClass ):
859- returned_dataclass_objects .append (o )
860- elif isinstance (o , BaseModel ):
861- returned_pydantic_objects .append (o )
862- else :
863- raise TypeError (f"Unexpected type: { type (o )} " )
864- assert sorted (orig_dataclass_objects ) == sorted (returned_dataclass_objects )
865- assert sorted (orig_pydantic_objects , key = lambda o : o .__class__ .__name__ ) == sorted (
866- returned_pydantic_objects , key = lambda o : o .__class__ .__name__
867- )
868- for o in returned_pydantic_objects :
869- o ._check_instance ()
870-
871-
872- @workflow .defn
873- class PydanticModelUsageWorkflow :
874- @workflow .run
875- async def run (self ) -> None :
876- for o in make_list_of_pydantic_objects ():
877- o ._check_instance ()
878-
879-
880- async def test_pydantic_model_usage_in_workflow (client : Client ):
881- new_config = client .config ()
882- new_config ["data_converter" ] = pydantic_data_converter
883- client = Client (** new_config )
884- task_queue_name = str (uuid .uuid4 ())
885-
886- async with Worker (
887- client ,
888- task_queue = task_queue_name ,
889- workflows = [PydanticModelUsageWorkflow ],
890- ):
891- await client .execute_workflow (
892- PydanticModelUsageWorkflow .run ,
893- id = str (uuid .uuid4 ()),
894- task_queue = task_queue_name ,
895- )
896-
897-
898- @workflow .defn
899- class DatetimeUsageWorkflow :
900- @workflow .run
901- async def run (self ) -> None :
902- dt = workflow .now ()
903- assert isinstance (dt , datetime )
904- assert issubclass (dt .__class__ , datetime )
905-
906-
907- async def test_datetime_usage_in_workflow (client : Client ):
908- new_config = client .config ()
909- new_config ["data_converter" ] = pydantic_data_converter
910- client = Client (** new_config )
911- task_queue_name = str (uuid .uuid4 ())
912-
913- async with Worker (
914- client ,
915- task_queue = task_queue_name ,
916- workflows = [DatetimeUsageWorkflow ],
917- ):
918- await client .execute_workflow (
919- DatetimeUsageWorkflow .run ,
920- id = str (uuid .uuid4 ()),
921- task_queue = task_queue_name ,
922- )
0 commit comments