1
1
# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
2
2
import dataclasses
3
3
import enum
4
+ import functools
4
5
import inspect
5
6
import itertools
6
7
import json
7
8
import re
8
9
from typing import (
9
10
Any ,
10
- Callable ,
11
11
Dict ,
12
12
Iterable ,
13
13
List ,
14
14
Literal ,
15
15
Mapping ,
16
16
Optional ,
17
- Protocol ,
18
17
Sequence ,
19
18
Set ,
20
19
Tuple ,
25
24
get_args ,
26
25
get_origin ,
27
26
get_type_hints ,
28
- runtime_checkable ,
29
27
)
30
28
31
29
__all__ = [
48
46
__to_snake_case_cache : Dict [str , str ] = {}
49
47
50
48
49
+ @functools .lru_cache (maxsize = 2048 )
51
50
def to_snake_case (s : str ) -> str :
52
51
result = __to_snake_case_cache .get (s , __not_valid )
53
52
if result is __not_valid :
@@ -66,6 +65,7 @@ def to_snake_case(s: str) -> str:
66
65
__to_snake_camel_cache : Dict [str , str ] = {}
67
66
68
67
68
+ @functools .lru_cache (maxsize = 2048 )
69
69
def to_camel_case (s : str ) -> str :
70
70
result = __to_snake_camel_cache .get (s , __not_valid )
71
71
if result is __not_valid :
@@ -91,20 +91,6 @@ def _decode_case(cls, s: str) -> str:
91
91
return to_snake_case (s )
92
92
93
93
94
- @runtime_checkable
95
- class HasCaseEncoder (Protocol ):
96
- @classmethod
97
- def _encode_case (cls , s : str ) -> str : # pragma: no cover
98
- ...
99
-
100
-
101
- @runtime_checkable
102
- class HasCaseDecoder (Protocol ):
103
- @classmethod
104
- def _decode_case (cls , s : str ) -> str : # pragma: no cover
105
- ...
106
-
107
-
108
94
_T = TypeVar ("_T" )
109
95
110
96
@@ -118,21 +104,13 @@ def _decode_case(cls, s: str) -> str:
118
104
return s
119
105
120
106
121
- __default_config = DefaultConfig ()
122
-
123
-
124
- def __get_config (obj : Any , entry_protocol : Type [_T ]) -> _T :
125
- if isinstance (obj , entry_protocol ):
126
- return obj
127
- return cast (_T , __default_config )
128
-
129
-
130
107
def encode_case (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
131
108
alias = field .metadata .get ("alias" , None )
132
109
if alias :
133
110
return str (alias )
134
-
135
- return __get_config (obj , HasCaseEncoder )._encode_case (field .name ) # type: ignore
111
+ if hasattr (obj , "_encode_case" ):
112
+ return str (obj ._encode_case (field .name ))
113
+ return field .name
136
114
137
115
138
116
def decode_case (type : Type [_T ], name : str ) -> str :
@@ -144,7 +122,10 @@ def decode_case(type: Type[_T], name: str) -> str:
144
122
if field :
145
123
return field .name
146
124
147
- return __get_config (type , HasCaseDecoder )._decode_case (name ) # type: ignore
125
+ if hasattr (type , "_decode_case" ):
126
+ return str (type ._decode_case (name )) # type: ignore[attr-defined]
127
+
128
+ return name
148
129
149
130
150
131
def __default (o : Any ) -> Any :
@@ -365,42 +346,34 @@ def as_dict(
365
346
value : Any ,
366
347
* ,
367
348
remove_defaults : bool = False ,
368
- dict_factory : Callable [[Any ], Dict [str , Any ]] = dict ,
369
349
encode : bool = True ,
370
350
) -> Dict [str , Any ]:
371
351
if not dataclasses .is_dataclass (value ):
372
352
raise TypeError ("as_dict() should be called on dataclass instances" )
373
353
374
- return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory , encode ))
354
+ return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , encode ))
375
355
376
356
377
357
def _as_dict_inner (
378
358
value : Any ,
379
359
remove_defaults : bool ,
380
- dict_factory : Callable [[Any ], Dict [str , Any ]],
381
360
encode : bool = True ,
382
361
) -> Any :
383
362
if dataclasses .is_dataclass (value ):
384
- result = []
385
- for f in dataclasses .fields (value ):
386
- v = _as_dict_inner (getattr (value , f .name ), remove_defaults , dict_factory )
387
-
388
- if remove_defaults and v == f .default :
389
- continue
390
- result .append ((encode_case (value , f ) if encode else f .name , v ))
391
- return dict_factory (result )
363
+ return {
364
+ encode_case (value , f ) if encode else f .name : _as_dict_inner (getattr (value , f .name ), remove_defaults )
365
+ for f in dataclasses .fields (value )
366
+ if not remove_defaults or getattr (value , f .name ) != f .default
367
+ }
392
368
393
369
if isinstance (value , tuple ) and hasattr (value , "_fields" ):
394
- return type ( value )( * [_as_dict_inner (v , remove_defaults , dict_factory ) for v in value ])
370
+ return [_as_dict_inner (v , remove_defaults ) for v in value ]
395
371
396
372
if isinstance (value , (list , tuple )):
397
- return type ( value )( _as_dict_inner (v , remove_defaults , dict_factory ) for v in value )
373
+ return [ _as_dict_inner (v , remove_defaults ) for v in value ]
398
374
399
375
if isinstance (value , dict ):
400
- return type (value )(
401
- (_as_dict_inner (k , remove_defaults , dict_factory ), _as_dict_inner (v , remove_defaults , dict_factory ))
402
- for k , v in value .items ()
403
- )
376
+ return {_as_dict_inner (k , remove_defaults ): _as_dict_inner (v , remove_defaults ) for k , v in value .items ()}
404
377
405
378
return value
406
379
0 commit comments