Skip to content

Commit 2b4ce26

Browse files
committed
perf: optimize performance of as_dict for dataclasses
1 parent 850c751 commit 2b4ce26

File tree

1 file changed

+19
-46
lines changed

1 file changed

+19
-46
lines changed

packages/core/src/robotcode/core/dataclasses.py

Lines changed: 19 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
22
import dataclasses
33
import enum
4+
import functools
45
import inspect
56
import itertools
67
import json
78
import re
89
from typing import (
910
Any,
10-
Callable,
1111
Dict,
1212
Iterable,
1313
List,
1414
Literal,
1515
Mapping,
1616
Optional,
17-
Protocol,
1817
Sequence,
1918
Set,
2019
Tuple,
@@ -25,7 +24,6 @@
2524
get_args,
2625
get_origin,
2726
get_type_hints,
28-
runtime_checkable,
2927
)
3028

3129
__all__ = [
@@ -48,6 +46,7 @@
4846
__to_snake_case_cache: Dict[str, str] = {}
4947

5048

49+
@functools.lru_cache(maxsize=2048)
5150
def to_snake_case(s: str) -> str:
5251
result = __to_snake_case_cache.get(s, __not_valid)
5352
if result is __not_valid:
@@ -66,6 +65,7 @@ def to_snake_case(s: str) -> str:
6665
__to_snake_camel_cache: Dict[str, str] = {}
6766

6867

68+
@functools.lru_cache(maxsize=2048)
6969
def to_camel_case(s: str) -> str:
7070
result = __to_snake_camel_cache.get(s, __not_valid)
7171
if result is __not_valid:
@@ -91,20 +91,6 @@ def _decode_case(cls, s: str) -> str:
9191
return to_snake_case(s)
9292

9393

94-
@runtime_checkable
95-
class HasCaseEncoder(Protocol):
96-
@classmethod
97-
def _encode_case(cls, s: str) -> str: # pragma: no cover
98-
...
99-
100-
101-
@runtime_checkable
102-
class HasCaseDecoder(Protocol):
103-
@classmethod
104-
def _decode_case(cls, s: str) -> str: # pragma: no cover
105-
...
106-
107-
10894
_T = TypeVar("_T")
10995

11096

@@ -118,21 +104,13 @@ def _decode_case(cls, s: str) -> str:
118104
return s
119105

120106

121-
__default_config = DefaultConfig()
122-
123-
124-
def __get_config(obj: Any, entry_protocol: Type[_T]) -> _T:
125-
if isinstance(obj, entry_protocol):
126-
return obj
127-
return cast(_T, __default_config)
128-
129-
130107
def encode_case(obj: Any, field: dataclasses.Field) -> str: # type: ignore
131108
alias = field.metadata.get("alias", None)
132109
if alias:
133110
return str(alias)
134-
135-
return __get_config(obj, HasCaseEncoder)._encode_case(field.name) # type: ignore
111+
if hasattr(obj, "_encode_case"):
112+
return str(obj._encode_case(field.name))
113+
return field.name
136114

137115

138116
def decode_case(type: Type[_T], name: str) -> str:
@@ -144,7 +122,10 @@ def decode_case(type: Type[_T], name: str) -> str:
144122
if field:
145123
return field.name
146124

147-
return __get_config(type, HasCaseDecoder)._decode_case(name) # type: ignore
125+
if hasattr(type, "_decode_case"):
126+
return str(type._decode_case(name)) # type: ignore[attr-defined]
127+
128+
return name
148129

149130

150131
def __default(o: Any) -> Any:
@@ -365,42 +346,34 @@ def as_dict(
365346
value: Any,
366347
*,
367348
remove_defaults: bool = False,
368-
dict_factory: Callable[[Any], Dict[str, Any]] = dict,
369349
encode: bool = True,
370350
) -> Dict[str, Any]:
371351
if not dataclasses.is_dataclass(value):
372352
raise TypeError("as_dict() should be called on dataclass instances")
373353

374-
return cast(Dict[str, Any], _as_dict_inner(value, remove_defaults, dict_factory, encode))
354+
return cast(Dict[str, Any], _as_dict_inner(value, remove_defaults, encode))
375355

376356

377357
def _as_dict_inner(
378358
value: Any,
379359
remove_defaults: bool,
380-
dict_factory: Callable[[Any], Dict[str, Any]],
381360
encode: bool = True,
382361
) -> Any:
383362
if dataclasses.is_dataclass(value):
384-
result = []
385-
for f in dataclasses.fields(value):
386-
v = _as_dict_inner(getattr(value, f.name), remove_defaults, dict_factory)
387-
388-
if remove_defaults and v == f.default:
389-
continue
390-
result.append((encode_case(value, f) if encode else f.name, v))
391-
return dict_factory(result)
363+
return {
364+
encode_case(value, f) if encode else f.name: _as_dict_inner(getattr(value, f.name), remove_defaults)
365+
for f in dataclasses.fields(value)
366+
if not remove_defaults or getattr(value, f.name) != f.default
367+
}
392368

393369
if isinstance(value, tuple) and hasattr(value, "_fields"):
394-
return type(value)(*[_as_dict_inner(v, remove_defaults, dict_factory) for v in value])
370+
return [_as_dict_inner(v, remove_defaults) for v in value]
395371

396372
if isinstance(value, (list, tuple)):
397-
return type(value)(_as_dict_inner(v, remove_defaults, dict_factory) for v in value)
373+
return [_as_dict_inner(v, remove_defaults) for v in value]
398374

399375
if isinstance(value, dict):
400-
return type(value)(
401-
(_as_dict_inner(k, remove_defaults, dict_factory), _as_dict_inner(v, remove_defaults, dict_factory))
402-
for k, v in value.items()
403-
)
376+
return {_as_dict_inner(k, remove_defaults): _as_dict_inner(v, remove_defaults) for k, v in value.items()}
404377

405378
return value
406379

0 commit comments

Comments
 (0)