|
| 1 | +import dataclasses |
1 | 2 | import uuid |
2 | 3 | from datetime import datetime, timedelta |
3 | 4 | from ipaddress import IPv4Address |
4 | | -from typing import Annotated, Any, List, Sequence, TypeVar |
| 5 | +from typing import Annotated, Any, List, Sequence, Tuple, TypeVar |
5 | 6 |
|
6 | 7 | from annotated_types import Len |
7 | 8 | from pydantic import BaseModel, Field, WithJsonSchema |
@@ -35,29 +36,8 @@ class MyPydanticModel(BaseModel): |
35 | 36 | datetime_short_sequence: ShortSequence[List[datetime]] |
36 | 37 |
|
37 | 38 |
|
38 | | -@activity.defn |
39 | | -async def my_activity(models: List[MyPydanticModel]) -> List[MyPydanticModel]: |
40 | | - activity.logger.info("Got models in activity: %s" % models) |
41 | | - return models |
42 | | - |
43 | | - |
44 | | -@workflow.defn |
45 | | -class MyWorkflow: |
46 | | - @workflow.run |
47 | | - async def run(self, models: List[MyPydanticModel]) -> List[MyPydanticModel]: |
48 | | - workflow.logger.info("Got models in workflow: %s" % models) |
49 | | - return await workflow.execute_activity( |
50 | | - my_activity, models, start_to_close_timeout=timedelta(minutes=1) |
51 | | - ) |
52 | | - |
53 | | - |
54 | | -async def test_field_conversion(client: Client): |
55 | | - new_config = client.config() |
56 | | - new_config["data_converter"] = pydantic_data_converter |
57 | | - client = Client(**new_config) |
58 | | - task_queue_name = str(uuid.uuid4()) |
59 | | - |
60 | | - orig_models = [ |
| 39 | +def make_pydantic_objects() -> List[MyPydanticModel]: |
| 40 | + return [ |
61 | 41 | MyPydanticModel( |
62 | 42 | ip_field=IPv4Address("127.0.0.1"), |
63 | 43 | datetime_field=datetime(2000, 1, 2, 3, 4, 5), |
@@ -94,16 +74,95 @@ async def test_field_conversion(client: Client): |
94 | 74 | ), |
95 | 75 | ] |
96 | 76 |
|
| 77 | + |
| 78 | +@activity.defn |
| 79 | +async def list_of_pydantic_models_activity( |
| 80 | + models: List[MyPydanticModel], |
| 81 | +) -> List[MyPydanticModel]: |
| 82 | + return models |
| 83 | + |
| 84 | + |
| 85 | +@workflow.defn |
| 86 | +class ListOfPydanticObjectsWorkflow: |
| 87 | + @workflow.run |
| 88 | + async def run(self, models: List[MyPydanticModel]) -> List[MyPydanticModel]: |
| 89 | + return await workflow.execute_activity( |
| 90 | + list_of_pydantic_models_activity, |
| 91 | + models, |
| 92 | + start_to_close_timeout=timedelta(minutes=1), |
| 93 | + ) |
| 94 | + |
| 95 | + |
| 96 | +async def test_field_conversion(client: Client): |
| 97 | + new_config = client.config() |
| 98 | + new_config["data_converter"] = pydantic_data_converter |
| 99 | + client = Client(**new_config) |
| 100 | + task_queue_name = str(uuid.uuid4()) |
| 101 | + |
| 102 | + orig_pydantic_objects = make_pydantic_objects() |
| 103 | + |
| 104 | + async with Worker( |
| 105 | + client, |
| 106 | + task_queue=task_queue_name, |
| 107 | + workflows=[ListOfPydanticObjectsWorkflow], |
| 108 | + activities=[list_of_pydantic_models_activity], |
| 109 | + ): |
| 110 | + round_tripped_pydantic_objects = await client.execute_workflow( |
| 111 | + ListOfPydanticObjectsWorkflow.run, |
| 112 | + orig_pydantic_objects, |
| 113 | + id=str(uuid.uuid4()), |
| 114 | + task_queue=task_queue_name, |
| 115 | + ) |
| 116 | + assert orig_pydantic_objects == round_tripped_pydantic_objects |
| 117 | + |
| 118 | + |
| 119 | +@dataclasses.dataclass |
| 120 | +class MyDataClass: |
| 121 | + int_field: int |
| 122 | + |
| 123 | + |
| 124 | +def make_dataclass_objects() -> List[MyDataClass]: |
| 125 | + return [MyDataClass(int_field=7)] |
| 126 | + |
| 127 | + |
| 128 | +@workflow.defn |
| 129 | +class MixedCollectionTypesWorkflow: |
| 130 | + @workflow.run |
| 131 | + async def run( |
| 132 | + self, input: Tuple[List[MyDataClass], List[MyPydanticModel]] |
| 133 | + ) -> Tuple[List[MyDataClass], List[MyPydanticModel]]: |
| 134 | + data_classes, pydantic_objects = input |
| 135 | + pydantic_objects = await workflow.execute_activity( |
| 136 | + list_of_pydantic_models_activity, |
| 137 | + pydantic_objects, |
| 138 | + start_to_close_timeout=timedelta(minutes=1), |
| 139 | + ) |
| 140 | + return data_classes, pydantic_objects |
| 141 | + |
| 142 | + |
| 143 | +async def test_mixed_collection_types(client: Client): |
| 144 | + new_config = client.config() |
| 145 | + new_config["data_converter"] = pydantic_data_converter |
| 146 | + client = Client(**new_config) |
| 147 | + task_queue_name = str(uuid.uuid4()) |
| 148 | + |
| 149 | + orig_dataclass_objects = make_dataclass_objects() |
| 150 | + orig_pydantic_objects = make_pydantic_objects() |
| 151 | + |
97 | 152 | async with Worker( |
98 | 153 | client, |
99 | 154 | task_queue=task_queue_name, |
100 | | - workflows=[MyWorkflow], |
101 | | - activities=[my_activity], |
| 155 | + workflows=[MixedCollectionTypesWorkflow], |
| 156 | + activities=[list_of_pydantic_models_activity], |
102 | 157 | ): |
103 | | - result = await client.execute_workflow( |
104 | | - MyWorkflow.run, |
105 | | - orig_models, |
| 158 | + ( |
| 159 | + round_tripped_dataclass_objects, |
| 160 | + round_tripped_pydantic_objects, |
| 161 | + ) = await client.execute_workflow( |
| 162 | + MixedCollectionTypesWorkflow.run, |
| 163 | + (orig_dataclass_objects, orig_pydantic_objects), |
106 | 164 | id=str(uuid.uuid4()), |
107 | 165 | task_queue=task_queue_name, |
108 | 166 | ) |
109 | | - assert orig_models == result |
| 167 | + assert orig_dataclass_objects == round_tripped_dataclass_objects |
| 168 | + assert orig_pydantic_objects == round_tripped_pydantic_objects |
0 commit comments