|
8 | 8 | import re
|
9 | 9 | from typing import (
|
10 | 10 | Any,
|
| 11 | + Callable, |
11 | 12 | Dict,
|
12 | 13 | Iterable,
|
13 | 14 | List,
|
@@ -104,13 +105,24 @@ def _decode_case(cls, s: str) -> str:
|
104 | 105 | return s
|
105 | 106 |
|
106 | 107 |
|
| 108 | +__field_name_cache: Dict[Tuple[Type[Any], dataclasses.Field[Any]], str] = {} |
| 109 | +__NOT_SET = object() |
| 110 | + |
| 111 | + |
107 | 112 | def encode_case(obj: Any, field: dataclasses.Field) -> str: # type: ignore
|
108 |
| - alias = field.metadata.get("alias", None) |
109 |
| - if alias: |
110 |
| - return str(alias) |
111 |
| - if hasattr(obj, "_encode_case"): |
112 |
| - return str(obj._encode_case(field.name)) |
113 |
| - return field.name |
| 113 | + t = obj if isinstance(obj, type) else type(obj) |
| 114 | + name = __field_name_cache.get((t, field), __NOT_SET) |
| 115 | + if name is __NOT_SET: |
| 116 | + alias = field.metadata.get("alias", None) |
| 117 | + if alias: |
| 118 | + name = str(alias) |
| 119 | + elif hasattr(obj, "_encode_case"): |
| 120 | + name = str(obj._encode_case(field.name)) |
| 121 | + else: |
| 122 | + name = field.name |
| 123 | + __field_name_cache[(t, field)] = name |
| 124 | + |
| 125 | + return cast(str, name) |
114 | 126 |
|
115 | 127 |
|
116 | 128 | def decode_case(type: Type[_T], name: str) -> str:
|
@@ -354,28 +366,106 @@ def as_dict(
|
354 | 366 | return cast(Dict[str, Any], _as_dict_inner(value, remove_defaults, encode))
|
355 | 367 |
|
356 | 368 |
|
357 |
| -def _as_dict_inner( |
358 |
| - value: Any, |
359 |
| - remove_defaults: bool, |
360 |
| - encode: bool = True, |
361 |
| -) -> Any: |
362 |
| - if dataclasses.is_dataclass(value): |
363 |
| - return { |
364 |
| - encode_case(value, f) if encode else f.name: _as_dict_inner(getattr(value, f.name), remove_defaults) |
365 |
| - for f in dataclasses.fields(value) |
366 |
| - if not remove_defaults or getattr(value, f.name) != f.default |
367 |
| - } |
| 369 | +NONETYPE = type(None) |
368 | 370 |
|
369 |
| - if isinstance(value, tuple) and hasattr(value, "_fields"): |
370 |
| - return [_as_dict_inner(v, remove_defaults) for v in value] |
371 | 371 |
|
372 |
| - if isinstance(value, (list, tuple)): |
373 |
| - return [_as_dict_inner(v, remove_defaults) for v in value] |
| 372 | +def _handle_basic_types(value: Any, _remove_defaults: bool, _encode: bool) -> Any: |
| 373 | + return value |
374 | 374 |
|
375 |
| - if isinstance(value, dict): |
376 |
| - return {_as_dict_inner(k, remove_defaults): _as_dict_inner(v, remove_defaults) for k, v in value.items()} |
377 | 375 |
|
378 |
| - return value |
| 376 | +__dataclasses_cache: Dict[Type[Any], Tuple[dataclasses.Field[Any], ...]] = {} |
| 377 | + |
| 378 | + |
| 379 | +def _handle_dataclass(value: Any, remove_defaults: bool, encode: bool) -> Dict[str, Any]: |
| 380 | + t = type(value) |
| 381 | + fields = __dataclasses_cache.get(t, None) |
| 382 | + if fields is None: |
| 383 | + fields = dataclasses.fields(value) |
| 384 | + __dataclasses_cache[t] = fields |
| 385 | + return { |
| 386 | + encode_case(t, f) if encode else f.name: _as_dict_inner(getattr(value, f.name), remove_defaults, encode) |
| 387 | + for f in fields |
| 388 | + if not remove_defaults or getattr(value, f.name) != f.default |
| 389 | + } |
| 390 | + |
| 391 | + |
| 392 | +def _handle_named_tuple(value: Any, remove_defaults: bool, encode: bool) -> List[Any]: |
| 393 | + return [_as_dict_inner(v, remove_defaults, encode) for v in value] |
| 394 | + |
| 395 | + |
| 396 | +def _handle_sequence(value: Any, remove_defaults: bool, encode: bool) -> List[Any]: |
| 397 | + return [_as_dict_inner(v, remove_defaults, encode) for v in value] |
| 398 | + |
| 399 | + |
| 400 | +def _handle_dict(value: Any, remove_defaults: bool, encode: bool) -> Dict[Any, Any]: |
| 401 | + return { |
| 402 | + _as_dict_inner(k, remove_defaults, encode): _as_dict_inner(v, remove_defaults, encode) for k, v in value.items() |
| 403 | + } |
| 404 | + |
| 405 | + |
| 406 | +def _handle_enum(value: enum.Enum, remove_defaults: bool, encode: bool) -> Any: |
| 407 | + return _as_dict_inner(value.value, remove_defaults, encode) |
| 408 | + |
| 409 | + |
| 410 | +def _handle_unknown_type(value: Any, _remove_defaults: bool, _encode: bool) -> Any: |
| 411 | + import warnings |
| 412 | + |
| 413 | + warnings.warn(f"Can't handle type {type(value)} with value {value!r}") |
| 414 | + return repr(value) |
| 415 | + |
| 416 | + |
| 417 | +__handlers: List[Tuple[Callable[[Any], bool], Callable[[Any, bool, bool], Any]]] = [ |
| 418 | + ( |
| 419 | + lambda value: type(value) in {int, bool, float, str, NONETYPE}, |
| 420 | + _handle_basic_types, |
| 421 | + ), |
| 422 | + ( |
| 423 | + lambda value: dataclasses.is_dataclass(value), |
| 424 | + _handle_dataclass, |
| 425 | + ), |
| 426 | + (lambda value: isinstance(value, enum.Enum), _handle_enum), |
| 427 | + ( |
| 428 | + lambda value: (isinstance(value, tuple) and hasattr(value, "_fields")), |
| 429 | + _handle_named_tuple, |
| 430 | + ), |
| 431 | + ( |
| 432 | + lambda value: isinstance(value, (list, tuple, set, frozenset)), |
| 433 | + _handle_sequence, |
| 434 | + ), |
| 435 | + ( |
| 436 | + lambda value: isinstance(value, dict), |
| 437 | + _handle_dict, |
| 438 | + ), |
| 439 | + ( |
| 440 | + lambda _value: True, |
| 441 | + _handle_unknown_type, |
| 442 | + ), |
| 443 | +] |
| 444 | + |
| 445 | +__handlers_cache: Dict[Type[Any], Callable[[Any, bool, bool], Any]] = {} |
| 446 | + |
| 447 | + |
| 448 | +def _as_dict_inner( |
| 449 | + value: Any, |
| 450 | + remove_defaults: bool, |
| 451 | + encode: bool, |
| 452 | +) -> Any: |
| 453 | + t = type(value) |
| 454 | + func = __handlers_cache.get(t, None) |
| 455 | + if func is None: |
| 456 | + if t in __handlers_cache: |
| 457 | + return __handlers_cache[t](value, remove_defaults, encode) |
| 458 | + |
| 459 | + for h in __handlers: |
| 460 | + if h[0](value): |
| 461 | + __handlers_cache[t] = h[1] |
| 462 | + func = h[1] |
| 463 | + break |
| 464 | + |
| 465 | + if func is None: |
| 466 | + raise TypeError(f"Can't handle type {t} with value {value!r}") |
| 467 | + |
| 468 | + return func(value, remove_defaults, encode) |
379 | 469 |
|
380 | 470 |
|
381 | 471 | class TypeValidationError(Exception):
|
|
0 commit comments