|
1 | | -import collections |
2 | 1 | import dataclasses |
3 | | -import decimal |
4 | | -import fractions |
5 | | -import re |
6 | | -import uuid |
7 | | -from datetime import date, datetime, time, timedelta, timezone |
8 | | -from enum import Enum, IntEnum |
9 | | -from ipaddress import IPv4Address |
| 2 | +from datetime import date, datetime, timedelta |
10 | 3 | from pathlib import Path |
11 | 4 | from typing import ( |
12 | 5 | Annotated, |
13 | 6 | Any, |
14 | 7 | Dict, |
15 | 8 | Generic, |
16 | | - Hashable, |
17 | 9 | List, |
18 | | - NamedTuple, |
19 | | - Optional, |
20 | | - Pattern, |
21 | 10 | Sequence, |
22 | | - Set, |
23 | 11 | Tuple, |
24 | 12 | TypeVar, |
25 | 13 | Union, |
26 | 14 | ) |
27 | 15 |
|
28 | 16 | from annotated_types import Len |
29 | 17 | from pydantic import BaseModel, Field, WithJsonSchema |
30 | | -from typing_extensions import TypedDict |
31 | 18 |
|
32 | | -SequenceType = TypeVar("SequenceType", bound=Sequence[Any]) |
33 | | -ShortSequence = Annotated[SequenceType, Len(max_length=2)] |
34 | | - |
35 | | - |
36 | | -class FruitEnum(str, Enum): |
37 | | - apple = "apple" |
38 | | - banana = "banana" |
39 | | - |
40 | | - |
41 | | -class NumberEnum(IntEnum): |
42 | | - one = 1 |
43 | | - two = 2 |
44 | | - |
45 | | - |
46 | | -class UserTypedDict(TypedDict): |
47 | | - name: str |
48 | | - id: int |
49 | | - |
50 | | - |
51 | | -class TypedDictModel(BaseModel): |
52 | | - typed_dict_field: UserTypedDict |
53 | | - |
54 | | - def _check_instance(self) -> None: |
55 | | - assert isinstance(self.typed_dict_field, dict) |
56 | | - assert self.typed_dict_field == {"name": "username", "id": 7} |
57 | | - |
58 | | - |
59 | | -def make_typed_dict_object() -> TypedDictModel: |
60 | | - return TypedDictModel(typed_dict_field={"name": "username", "id": 7}) |
61 | | - |
62 | | - |
63 | | -class StandardTypesModel(BaseModel): |
64 | | - # Boolean |
65 | | - bool_field: bool |
66 | | - bool_field_int: bool |
67 | | - bool_field_str: bool |
68 | | - |
69 | | - # Numbers |
70 | | - int_field: int |
71 | | - float_field: float |
72 | | - decimal_field: decimal.Decimal |
73 | | - complex_field: complex |
74 | | - fraction_field: fractions.Fraction |
75 | | - |
76 | | - # Strings and Bytes |
77 | | - str_field: str |
78 | | - bytes_field: bytes |
79 | | - |
80 | | - # None |
81 | | - none_field: None |
82 | | - |
83 | | - # Enums |
84 | | - str_enum_field: FruitEnum |
85 | | - int_enum_field: NumberEnum |
86 | | - |
87 | | - # Collections |
88 | | - list_field: list |
89 | | - tuple_field: tuple |
90 | | - set_field: set |
91 | | - frozenset_field: frozenset |
92 | | - deque_field: collections.deque |
93 | | - sequence_field: Sequence[int] |
94 | | - # Iterable[int] supported but not tested since original vs round-tripped do not compare equal |
95 | | - |
96 | | - # Mappings |
97 | | - dict_field: dict |
98 | | - # defaultdict_field: collections.defaultdict |
99 | | - counter_field: collections.Counter |
100 | | - typed_dict_field: UserTypedDict |
101 | | - |
102 | | - # Other Types |
103 | | - pattern_field: Pattern |
104 | | - hashable_field: Hashable |
105 | | - any_field: Any |
106 | | - # callable_field: Callable |
107 | | - |
108 | | - def _check_instance(self) -> None: |
109 | | - # Boolean checks |
110 | | - assert isinstance(self.bool_field, bool) |
111 | | - assert self.bool_field is True |
112 | | - assert isinstance(self.bool_field_int, bool) |
113 | | - assert self.bool_field_int is True |
114 | | - assert isinstance(self.bool_field_str, bool) |
115 | | - assert self.bool_field_str is True |
116 | | - |
117 | | - # Number checks |
118 | | - assert isinstance(self.int_field, int) |
119 | | - assert self.int_field == 42 |
120 | | - assert isinstance(self.float_field, float) |
121 | | - assert self.float_field == 3.14 |
122 | | - assert isinstance(self.decimal_field, decimal.Decimal) |
123 | | - assert self.decimal_field == decimal.Decimal("3.14") |
124 | | - assert isinstance(self.complex_field, complex) |
125 | | - assert self.complex_field == complex(1, 2) |
126 | | - assert isinstance(self.fraction_field, fractions.Fraction) |
127 | | - assert self.fraction_field == fractions.Fraction(22, 7) |
128 | | - |
129 | | - # String and Bytes checks |
130 | | - assert isinstance(self.str_field, str) |
131 | | - assert self.str_field == "hello" |
132 | | - assert isinstance(self.bytes_field, bytes) |
133 | | - assert self.bytes_field == b"world" |
134 | | - |
135 | | - # None check |
136 | | - assert self.none_field is None |
137 | | - |
138 | | - # Enum checks |
139 | | - assert isinstance(self.str_enum_field, Enum) |
140 | | - assert isinstance(self.int_enum_field, IntEnum) |
141 | | - |
142 | | - # Collection checks |
143 | | - assert isinstance(self.list_field, list) |
144 | | - assert self.list_field == [1, 2, 3] |
145 | | - assert isinstance(self.tuple_field, tuple) |
146 | | - assert self.tuple_field == (1, 2, 3) |
147 | | - assert isinstance(self.set_field, set) |
148 | | - assert self.set_field == {1, 2, 3} |
149 | | - assert isinstance(self.frozenset_field, frozenset) |
150 | | - assert self.frozenset_field == frozenset([1, 2, 3]) |
151 | | - assert isinstance(self.deque_field, collections.deque) |
152 | | - assert list(self.deque_field) == [1, 2, 3] |
153 | | - assert isinstance(self.sequence_field, list) |
154 | | - assert list(self.sequence_field) == [1, 2, 3] |
155 | | - |
156 | | - # Mapping checks |
157 | | - assert isinstance(self.dict_field, dict) |
158 | | - assert self.dict_field == {"a": 1, "b": 2} |
159 | | - # assert isinstance(self.defaultdict_field, collections.defaultdict) |
160 | | - # assert dict(self.defaultdict_field) == {"a": 1, "b": 2} |
161 | | - assert isinstance(self.counter_field, collections.Counter) |
162 | | - assert dict(self.counter_field) == {"a": 1, "b": 2} |
163 | | - assert isinstance(self.typed_dict_field, dict) |
164 | | - assert self.typed_dict_field == {"name": "username", "id": 7} |
165 | | - |
166 | | - # Other type checks |
167 | | - assert isinstance(self.pattern_field, Pattern) |
168 | | - assert self.pattern_field.pattern == r"\d+" |
169 | | - assert isinstance(self.hashable_field, Hashable) |
170 | | - assert self.hashable_field == "test" |
171 | | - assert self.any_field == "anything goes" |
172 | | - # assert callable(self.callable_field) |
173 | | - |
174 | | - |
175 | | -def make_standard_types_object() -> StandardTypesModel: |
176 | | - return StandardTypesModel( |
177 | | - # Boolean |
178 | | - bool_field=True, |
179 | | - bool_field_int=1, # type: ignore |
180 | | - bool_field_str="true", # type: ignore |
181 | | - # Numbers |
182 | | - int_field=42, |
183 | | - float_field=3.14, |
184 | | - decimal_field=decimal.Decimal("3.14"), |
185 | | - complex_field=complex(1, 2), |
186 | | - fraction_field=fractions.Fraction(22, 7), |
187 | | - # Strings and Bytes |
188 | | - str_field="hello", |
189 | | - bytes_field=b"world", |
190 | | - # None |
191 | | - none_field=None, |
192 | | - # Enums |
193 | | - str_enum_field=FruitEnum.apple, |
194 | | - int_enum_field=NumberEnum.one, |
195 | | - # Collections |
196 | | - # these cast input to list, tuple, set, etc. |
197 | | - list_field={1, 2, 3}, # type: ignore |
198 | | - tuple_field=(1, 2, 3), |
199 | | - set_field={1, 2, 3}, |
200 | | - frozenset_field=frozenset([1, 2, 3]), |
201 | | - deque_field=collections.deque([1, 2, 3]), |
202 | | - # other sequence types are converted to list, as documented |
203 | | - sequence_field=[1, 2, 3], |
204 | | - # Mappings |
205 | | - dict_field={"a": 1, "b": 2}, |
206 | | - # defaultdict_field=collections.defaultdict(int, {"a": 1, "b": 2}), |
207 | | - counter_field=collections.Counter({"a": 1, "b": 2}), |
208 | | - typed_dict_field={"name": "username", "id": 7}, |
209 | | - # Other Types |
210 | | - pattern_field=re.compile(r"\d+"), |
211 | | - hashable_field="test", |
212 | | - any_field="anything goes", |
213 | | - # callable_field=lambda x: x, |
| 19 | +from temporalio import workflow |
| 20 | + |
| 21 | +# Define some of the models outside the sandbox |
| 22 | +with workflow.unsafe.imports_passed_through(): |
| 23 | + from tests.contrib.pydantic.models_2 import ( |
| 24 | + ComplexTypesModel, |
| 25 | + SpecialTypesModel, |
| 26 | + StandardTypesModel, |
| 27 | + make_complex_types_object, |
| 28 | + make_special_types_object, |
| 29 | + make_standard_types_object, |
214 | 30 | ) |
215 | 31 |
|
216 | | - |
217 | | -class Point(NamedTuple): |
218 | | - x: int |
219 | | - y: int |
220 | | - |
221 | | - |
222 | | -class ComplexTypesModel(BaseModel): |
223 | | - list_field: List[str] |
224 | | - dict_field: Dict[str, int] |
225 | | - set_field: Set[int] |
226 | | - tuple_field: Tuple[str, int] |
227 | | - union_field: Union[str, int] |
228 | | - optional_field: Optional[str] |
229 | | - named_tuple_field: Point |
230 | | - |
231 | | - def _check_instance(self) -> None: |
232 | | - assert isinstance(self.list_field, list) |
233 | | - assert isinstance(self.dict_field, dict) |
234 | | - assert isinstance(self.set_field, set) |
235 | | - assert isinstance(self.tuple_field, tuple) |
236 | | - assert isinstance(self.union_field, str) |
237 | | - assert isinstance(self.optional_field, str) |
238 | | - assert self.list_field == ["a", "b", "c"] |
239 | | - assert self.dict_field == {"x": 1, "y": 2} |
240 | | - assert self.set_field == {1, 2, 3} |
241 | | - assert self.tuple_field == ("hello", 42) |
242 | | - assert self.union_field == "string_or_int" |
243 | | - assert self.optional_field == "present" |
244 | | - assert self.named_tuple_field == Point(x=1, y=2) |
245 | | - |
246 | | - |
247 | | -def make_complex_types_object() -> ComplexTypesModel: |
248 | | - return ComplexTypesModel( |
249 | | - list_field=["a", "b", "c"], |
250 | | - dict_field={"x": 1, "y": 2}, |
251 | | - set_field={1, 2, 3}, |
252 | | - tuple_field=("hello", 42), |
253 | | - union_field="string_or_int", |
254 | | - optional_field="present", |
255 | | - named_tuple_field=Point(x=1, y=2), |
256 | | - ) |
257 | | - |
258 | | - |
259 | | -class SpecialTypesModel(BaseModel): |
260 | | - datetime_field: datetime |
261 | | - datetime_field_int: datetime |
262 | | - datetime_field_float: datetime |
263 | | - datetime_field_str_formatted: datetime |
264 | | - datetime_field_str_int: datetime |
265 | | - datetime_field_date: datetime |
266 | | - |
267 | | - time_field: time |
268 | | - time_field_str: time |
269 | | - |
270 | | - date_field: date |
271 | | - timedelta_field: timedelta |
272 | | - path_field: Path |
273 | | - uuid_field: uuid.UUID |
274 | | - ip_field: IPv4Address |
275 | | - |
276 | | - def _check_instance(self) -> None: |
277 | | - dt = datetime(2000, 1, 2, 3, 4, 5) |
278 | | - dtz = datetime(2000, 1, 2, 3, 4, 5, tzinfo=timezone.utc) |
279 | | - assert isinstance(self.datetime_field, datetime) |
280 | | - assert isinstance(self.datetime_field_int, datetime) |
281 | | - assert isinstance(self.datetime_field_float, datetime) |
282 | | - assert isinstance(self.datetime_field_str_formatted, datetime) |
283 | | - assert isinstance(self.datetime_field_str_int, datetime) |
284 | | - assert isinstance(self.datetime_field_date, datetime) |
285 | | - assert isinstance(self.timedelta_field, timedelta) |
286 | | - assert isinstance(self.path_field, Path) |
287 | | - assert isinstance(self.uuid_field, uuid.UUID) |
288 | | - assert isinstance(self.ip_field, IPv4Address) |
289 | | - assert self.datetime_field == dt |
290 | | - assert self.datetime_field_int == dtz |
291 | | - assert self.datetime_field_float == dtz |
292 | | - assert self.datetime_field_str_formatted == dtz |
293 | | - assert self.datetime_field_str_int == dtz |
294 | | - assert self.datetime_field_date == datetime(2000, 1, 2) |
295 | | - assert self.time_field == time(3, 4, 5) |
296 | | - assert self.time_field_str == time(3, 4, 5, tzinfo=timezone.utc) |
297 | | - assert self.date_field == date(2000, 1, 2) |
298 | | - assert self.timedelta_field == timedelta(days=1, hours=2) |
299 | | - assert self.path_field == Path("test/path") |
300 | | - assert self.uuid_field == uuid.UUID("12345678-1234-5678-1234-567812345678") |
301 | | - assert self.ip_field == IPv4Address("127.0.0.1") |
302 | | - |
303 | | - |
304 | | -def make_special_types_object() -> SpecialTypesModel: |
305 | | - return SpecialTypesModel( |
306 | | - datetime_field=datetime(2000, 1, 2, 3, 4, 5), |
307 | | - # 946800245 |
308 | | - datetime_field_int=946782245, # type: ignore |
309 | | - datetime_field_float=946782245.0, # type: ignore |
310 | | - datetime_field_str_formatted="2000-01-02T03:04:05Z", # type: ignore |
311 | | - datetime_field_str_int="946782245", # type: ignore |
312 | | - datetime_field_date=datetime(2000, 1, 2), |
313 | | - time_field=time(3, 4, 5), |
314 | | - time_field_str="03:04:05Z", # type: ignore |
315 | | - date_field=date(2000, 1, 2), |
316 | | - timedelta_field=timedelta(days=1, hours=2), |
317 | | - path_field=Path("test/path"), |
318 | | - uuid_field=uuid.UUID("12345678-1234-5678-1234-567812345678"), |
319 | | - ip_field=IPv4Address("127.0.0.1"), |
320 | | - ) |
| 32 | +SequenceType = TypeVar("SequenceType", bound=Sequence[Any]) |
| 33 | +ShortSequence = Annotated[SequenceType, Len(max_length=2)] |
321 | 34 |
|
322 | 35 |
|
323 | 36 | class ChildModel(BaseModel): |
|
0 commit comments