1- # pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
21import dataclasses
32import enum
43import functools
54import inspect
65import itertools
76import json
87import re
8+ import types
99from typing import (
1010 Any ,
1111 Callable ,
4242_RE_SNAKE_CASE_2 = re .compile (r"[A-Z]" )
4343
4444
45- __not_valid = object ()
46-
47- __to_snake_case_cache : Dict [str , str ] = {}
45+ __NOT_SET = object ()
4846
4947
5048@functools .lru_cache (maxsize = 1024 )
5149def to_snake_case (s : str ) -> str :
52- result = __to_snake_case_cache .get (s , __not_valid )
53- if result is __not_valid :
54- s = _RE_SNAKE_CASE_1 .sub ("_" , s )
55- if not s :
56- result = s
57- else :
58- result = s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
59- __to_snake_case_cache [s ] = result
60- return cast (str , result )
50+ s = _RE_SNAKE_CASE_1 .sub ("_" , s )
51+ if not s :
52+ return s
53+ return s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
6154
6255
6356_RE_CAMEL_CASE_1 = re .compile (r"^[\-_\.]" )
6457_RE_CAMEL_CASE_2 = re .compile (r"[\-_\.\s]([a-z])" )
6558
66- __to_snake_camel_cache : Dict [str , str ] = {}
67-
6859
6960@functools .lru_cache (maxsize = 1024 )
7061def to_camel_case (s : str ) -> str :
71- result = __to_snake_camel_cache .get (s , __not_valid )
72- if result is __not_valid :
73- s = _RE_CAMEL_CASE_1 .sub ("" , s )
74- if not s :
75- result = s
76- else :
77- result = str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (lambda matched : str (matched .group (1 )).upper (), s [1 :])
78- __to_snake_camel_cache [s ] = result
79- return cast (str , result )
62+ s = _RE_CAMEL_CASE_1 .sub ("" , s )
63+ if not s :
64+ return s
65+ return str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (lambda matched : str (matched .group (1 )).upper (), s [1 :])
8066
8167
8268class CamelSnakeMixin :
@@ -102,49 +88,37 @@ def _decode_case(cls, s: str) -> str:
10288 return s
10389
10490
105- __field_name_cache : Dict [Tuple [Type [Any ], dataclasses .Field ], str ] = {} # type: ignore
106- __NOT_SET = object ()
91+ @functools .lru_cache (maxsize = 1024 )
92+ def _encode_case_for_field_name_cached (t : Type [Any ], field : dataclasses .Field [Any ]) -> str :
93+ alias = field .metadata .get ("alias" , None )
94+ if alias :
95+ return str (alias )
10796
97+ if hasattr (t , "_encode_case" ):
98+ return str (t ._encode_case (field .name ))
10899
109- def encode_case_for_field_name (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
110- t = obj if isinstance (obj , type ) else type (obj )
111- name = __field_name_cache .get ((t , field ), __NOT_SET )
112- if name is __NOT_SET :
113- alias = field .metadata .get ("alias" , None )
114- if alias :
115- name = str (alias )
116- elif hasattr (obj , "_encode_case" ):
117- name = str (obj ._encode_case (field .name ))
118- else :
119- name = field .name
120- __field_name_cache [(t , field )] = name
100+ return field .name
121101
122- return cast (str , name )
123102
124-
125- __decode_case_cache : Dict [Tuple [Type [Any ], str ], str ] = {}
103+ def encode_case_for_field_name (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
104+ t = obj if isinstance (obj , type ) else type (obj )
105+ return _encode_case_for_field_name_cached (t , field )
126106
127107
108+ @functools .lru_cache (maxsize = 1024 )
128109def _decode_case_for_member_name (type : Type [Any ], name : str ) -> str :
129- r = __decode_case_cache .get ((type , name ), __NOT_SET )
130- if r is __NOT_SET :
131- if dataclasses .is_dataclass (type ):
132- field = next (
133- (f for f in get_dataclass_fields (type ) if f .metadata .get ("alias" , None ) == name ),
134- None ,
135- )
136- if field :
137- r = field .name
138-
139- if r is __NOT_SET :
140- if hasattr (type , "_decode_case" ):
141- r = str (type ._decode_case (name ))
142- else :
143- r = name
110+ if dataclasses .is_dataclass (type ):
111+ field = next (
112+ (f for f in get_dataclass_fields (type ) if f .metadata .get ("alias" , None ) == name ),
113+ None ,
114+ )
115+ if field :
116+ return field .name
144117
145- __decode_case_cache [(type , name )] = cast (str , r )
118+ if hasattr (type , "_decode_case" ):
119+ return str (type ._decode_case (name ))
146120
147- return cast ( str , r )
121+ return name
148122
149123
150124NONETYPE = type (None )
@@ -313,6 +287,11 @@ def __from_dict_handle_enum(value: Any, t: Type[Any], strict: bool) -> Tuple[Any
313287def __from_dict_handle_basic_types (value : Any , t : Type [Any ], strict : bool ) -> Tuple [Any , bool ]:
314288 if isinstance (value , t ):
315289 return value , True
290+
291+ if not strict :
292+ if t is float and isinstance (value , int ):
293+ return float (value ), True
294+
316295 return None , False
317296
318297
@@ -330,6 +309,10 @@ def __from_dict_handle_mapping(value: Any, t: Type[Any], strict: bool) -> Tuple[
330309 return None , False
331310
332311
312+ def is_union (tp : Optional [Type [Any ]]) -> bool :
313+ return tp is Union or tp is types .UnionType # type: ignore[comparison-overlap]
314+
315+
333316__from_dict_handlers : List [
334317 Tuple [
335318 Callable [[Type [Any ]], bool ],
@@ -340,7 +323,7 @@ def __from_dict_handle_mapping(value: Any, t: Type[Any], strict: bool) -> Tuple[
340323 lambda t : t in {int , bool , float , str , NONETYPE },
341324 __from_dict_handle_basic_types ,
342325 ),
343- (lambda t : _get_origin_cached (t ) is Union , __from_dict_handle_union ),
326+ (lambda t : is_union ( _get_origin_cached (t )) , __from_dict_handle_union ),
344327 (lambda t : _get_origin_cached (t ) is Literal , __from_dict_handle_literal ),
345328 (__is_enum , __from_dict_handle_enum ),
346329 (
@@ -427,7 +410,7 @@ def from_dict(
427410 if origin is Literal :
428411 continue
429412
430- cased_value : Dict [str , Any ] = {_decode_case_for_member_name (t , k ): v for k , v in value .items ()}
413+ cased_value : Dict [str , Any ] = {_decode_case_for_member_name (t , k ): v for k , v in value .items ()} # type: ignore[arg-type]
431414
432415 type_hints = _get_type_hints_cached (origin or t )
433416 try :
@@ -618,7 +601,7 @@ def __str__(self) -> str:
618601 return f"{ s } (errors = { self .errors !r} )"
619602
620603
621- def validate_types (expected_types : Union [type , Tuple [type , ...], None ], value : Any ) -> List [str ]:
604+ def validate_types (expected_types : Union [Type [ Any ] , Tuple [Type [ Any ] , ...], None ], value : Any ) -> List [str ]:
622605 if expected_types is None :
623606 return []
624607
0 commit comments