Skip to content

Commit 428226a

Browse files
committed
refactor(core): simplify casing caches and add support for PEP 604 unions
1 parent a956c46 commit 428226a

File tree

1 file changed

+44
-61
lines changed

1 file changed

+44
-61
lines changed

packages/core/src/robotcode/core/utils/dataclasses.py

Lines changed: 44 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
21
import dataclasses
32
import enum
43
import functools
54
import inspect
65
import itertools
76
import json
87
import re
8+
import types
99
from typing import (
1010
Any,
1111
Callable,
@@ -42,41 +42,27 @@
4242
_RE_SNAKE_CASE_2 = re.compile(r"[A-Z]")
4343

4444

45-
__not_valid = object()
46-
47-
__to_snake_case_cache: Dict[str, str] = {}
45+
__NOT_SET = object()
4846

4947

5048
@functools.lru_cache(maxsize=1024)
5149
def to_snake_case(s: str) -> str:
52-
result = __to_snake_case_cache.get(s, __not_valid)
53-
if result is __not_valid:
54-
s = _RE_SNAKE_CASE_1.sub("_", s)
55-
if not s:
56-
result = s
57-
else:
58-
result = s[0].lower() + _RE_SNAKE_CASE_2.sub(lambda matched: "_" + matched.group(0).lower(), s[1:])
59-
__to_snake_case_cache[s] = result
60-
return cast(str, result)
50+
s = _RE_SNAKE_CASE_1.sub("_", s)
51+
if not s:
52+
return s
53+
return s[0].lower() + _RE_SNAKE_CASE_2.sub(lambda matched: "_" + matched.group(0).lower(), s[1:])
6154

6255

6356
_RE_CAMEL_CASE_1 = re.compile(r"^[\-_\.]")
6457
_RE_CAMEL_CASE_2 = re.compile(r"[\-_\.\s]([a-z])")
6558

66-
__to_snake_camel_cache: Dict[str, str] = {}
67-
6859

6960
@functools.lru_cache(maxsize=1024)
7061
def to_camel_case(s: str) -> str:
71-
result = __to_snake_camel_cache.get(s, __not_valid)
72-
if result is __not_valid:
73-
s = _RE_CAMEL_CASE_1.sub("", s)
74-
if not s:
75-
result = s
76-
else:
77-
result = str(s[0]).lower() + _RE_CAMEL_CASE_2.sub(lambda matched: str(matched.group(1)).upper(), s[1:])
78-
__to_snake_camel_cache[s] = result
79-
return cast(str, result)
62+
s = _RE_CAMEL_CASE_1.sub("", s)
63+
if not s:
64+
return s
65+
return str(s[0]).lower() + _RE_CAMEL_CASE_2.sub(lambda matched: str(matched.group(1)).upper(), s[1:])
8066

8167

8268
class CamelSnakeMixin:
@@ -102,49 +88,37 @@ def _decode_case(cls, s: str) -> str:
10288
return s
10389

10490

105-
__field_name_cache: Dict[Tuple[Type[Any], dataclasses.Field], str] = {} # type: ignore
106-
__NOT_SET = object()
91+
@functools.lru_cache(maxsize=1024)
92+
def _encode_case_for_field_name_cached(t: Type[Any], field: dataclasses.Field[Any]) -> str:
93+
alias = field.metadata.get("alias", None)
94+
if alias:
95+
return str(alias)
10796

97+
if hasattr(t, "_encode_case"):
98+
return str(t._encode_case(field.name))
10899

109-
def encode_case_for_field_name(obj: Any, field: dataclasses.Field) -> str: # type: ignore
110-
t = obj if isinstance(obj, type) else type(obj)
111-
name = __field_name_cache.get((t, field), __NOT_SET)
112-
if name is __NOT_SET:
113-
alias = field.metadata.get("alias", None)
114-
if alias:
115-
name = str(alias)
116-
elif hasattr(obj, "_encode_case"):
117-
name = str(obj._encode_case(field.name))
118-
else:
119-
name = field.name
120-
__field_name_cache[(t, field)] = name
100+
return field.name
121101

122-
return cast(str, name)
123102

124-
125-
__decode_case_cache: Dict[Tuple[Type[Any], str], str] = {}
103+
def encode_case_for_field_name(obj: Any, field: dataclasses.Field) -> str: # type: ignore
104+
t = obj if isinstance(obj, type) else type(obj)
105+
return _encode_case_for_field_name_cached(t, field)
126106

127107

108+
@functools.lru_cache(maxsize=1024)
128109
def _decode_case_for_member_name(type: Type[Any], name: str) -> str:
129-
r = __decode_case_cache.get((type, name), __NOT_SET)
130-
if r is __NOT_SET:
131-
if dataclasses.is_dataclass(type):
132-
field = next(
133-
(f for f in get_dataclass_fields(type) if f.metadata.get("alias", None) == name),
134-
None,
135-
)
136-
if field:
137-
r = field.name
138-
139-
if r is __NOT_SET:
140-
if hasattr(type, "_decode_case"):
141-
r = str(type._decode_case(name))
142-
else:
143-
r = name
110+
if dataclasses.is_dataclass(type):
111+
field = next(
112+
(f for f in get_dataclass_fields(type) if f.metadata.get("alias", None) == name),
113+
None,
114+
)
115+
if field:
116+
return field.name
144117

145-
__decode_case_cache[(type, name)] = cast(str, r)
118+
if hasattr(type, "_decode_case"):
119+
return str(type._decode_case(name))
146120

147-
return cast(str, r)
121+
return name
148122

149123

150124
NONETYPE = type(None)
@@ -313,6 +287,11 @@ def __from_dict_handle_enum(value: Any, t: Type[Any], strict: bool) -> Tuple[Any
313287
def __from_dict_handle_basic_types(value: Any, t: Type[Any], strict: bool) -> Tuple[Any, bool]:
314288
if isinstance(value, t):
315289
return value, True
290+
291+
if not strict:
292+
if t is float and isinstance(value, int):
293+
return float(value), True
294+
316295
return None, False
317296

318297

@@ -330,6 +309,10 @@ def __from_dict_handle_mapping(value: Any, t: Type[Any], strict: bool) -> Tuple[
330309
return None, False
331310

332311

312+
def is_union(tp: Optional[Type[Any]]) -> bool:
313+
return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap]
314+
315+
333316
__from_dict_handlers: List[
334317
Tuple[
335318
Callable[[Type[Any]], bool],
@@ -340,7 +323,7 @@ def __from_dict_handle_mapping(value: Any, t: Type[Any], strict: bool) -> Tuple[
340323
lambda t: t in {int, bool, float, str, NONETYPE},
341324
__from_dict_handle_basic_types,
342325
),
343-
(lambda t: _get_origin_cached(t) is Union, __from_dict_handle_union),
326+
(lambda t: is_union(_get_origin_cached(t)), __from_dict_handle_union),
344327
(lambda t: _get_origin_cached(t) is Literal, __from_dict_handle_literal),
345328
(__is_enum, __from_dict_handle_enum),
346329
(
@@ -427,7 +410,7 @@ def from_dict(
427410
if origin is Literal:
428411
continue
429412

430-
cased_value: Dict[str, Any] = {_decode_case_for_member_name(t, k): v for k, v in value.items()}
413+
cased_value: Dict[str, Any] = {_decode_case_for_member_name(t, k): v for k, v in value.items()} # type: ignore[arg-type]
431414

432415
type_hints = _get_type_hints_cached(origin or t)
433416
try:
@@ -618,7 +601,7 @@ def __str__(self) -> str:
618601
return f"{s} (errors = {self.errors!r})"
619602

620603

621-
def validate_types(expected_types: Union[type, Tuple[type, ...], None], value: Any) -> List[str]:
604+
def validate_types(expected_types: Union[Type[Any], Tuple[Type[Any], ...], None], value: Any) -> List[str]:
622605
if expected_types is None:
623606
return []
624607

0 commit comments

Comments
 (0)