Skip to content

Commit e6b899a

Browse files
authored
optimize frozen dict get item (#6021)
* optimize frozen dict get item * asdict is weird * make that into a dataclass generic
1 parent e4300ad commit e6b899a

File tree

4 files changed

+39
-20
lines changed

4 files changed

+39
-20
lines changed

reflex/components/component.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,10 @@ def _deterministic_hash(value: object) -> str:
515515
return _hash_str(
516516
str((value._js_expr, _deterministic_hash(value._get_all_var_data())))
517517
)
518-
if isinstance(value, VarData):
519-
return _hash_dict(dataclasses.asdict(value))
518+
if dataclasses.is_dataclass(value):
519+
return _hash_dict({
520+
k.name: getattr(value, k.name) for k in dataclasses.fields(value)
521+
})
520522
if isinstance(value, BaseComponent):
521523
# If the value is a component, hash its rendered code.
522524
return _hash_dict(value.render())

reflex/istate/data.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
from collections.abc import Mapping
5+
from types import MappingProxyType
56
from typing import TYPE_CHECKING
67
from urllib.parse import _NetlocResultMixinStr, parse_qsl, urlsplit
78

@@ -12,20 +13,34 @@
1213

1314
@dataclasses.dataclass(frozen=True, init=False)
1415
class _FrozenDictStrStr(Mapping[str, str]):
15-
_data: tuple[tuple[str, str], ...]
16+
_data: MappingProxyType[str, str]
1617

1718
def __init__(self, **kwargs):
18-
object.__setattr__(self, "_data", tuple(sorted(kwargs.items())))
19+
object.__setattr__(
20+
self, "_data", MappingProxyType(dict(sorted(kwargs.items())))
21+
)
1922

2023
def __getitem__(self, key: str) -> str:
21-
return dict(self._data)[key]
24+
return self._data[key]
2225

2326
def __iter__(self):
24-
return (x[0] for x in self._data)
27+
return iter(self._data)
2528

2629
def __len__(self):
2730
return len(self._data)
2831

32+
def __hash__(self) -> int:
33+
return hash(frozenset(self._data.items()))
34+
35+
def __getstate__(self) -> object:
36+
return dict(self._data)
37+
38+
def __setstate__(self, state: object) -> None:
39+
if not isinstance(state, dict):
40+
msg = "Invalid state for _FrozenDictStrStr"
41+
raise TypeError(msg)
42+
object.__setattr__(self, "_data", MappingProxyType(state))
43+
2944

3045
@dataclasses.dataclass(frozen=True)
3146
class _HeaderData:
@@ -170,7 +185,7 @@ def from_router_data(cls, router_data: dict) -> "PageData":
170185

171186
@serializer(to=dict)
172187
def _serialize_page_data(obj: PageData) -> dict:
173-
return dataclasses.asdict(obj)
188+
return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)}
174189

175190

176191
@dataclasses.dataclass(frozen=True)
@@ -200,7 +215,7 @@ def from_router_data(cls, router_data: dict) -> "SessionData":
200215

201216
@serializer(to=dict)
202217
def _serialize_session_data(obj: SessionData) -> dict:
203-
return dataclasses.asdict(obj)
218+
return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)}
204219

205220

206221
@dataclasses.dataclass(frozen=True)

reflex/vars/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,8 +1508,8 @@ def _create_literal_var(
15081508
if dataclasses.is_dataclass(value) and not isinstance(value, type):
15091509
return LiteralObjectVar.create(
15101510
{
1511-
k: (None if callable(v) else v)
1512-
for k, v in dataclasses.asdict(value).items()
1511+
k.name: (None if callable(v := getattr(value, k.name)) else v)
1512+
for k in dataclasses.fields(value)
15131513
},
15141514
_var_type=type(value),
15151515
_var_data=_var_data,
@@ -1591,8 +1591,8 @@ def _get_all_var_data_without_creating_var_dispatch(
15911591

15921592
if dataclasses.is_dataclass(value) and not isinstance(value, type):
15931593
return LiteralObjectVar._get_all_var_data_without_creating_var({
1594-
k: (None if callable(v) else v)
1595-
for k, v in dataclasses.asdict(value).items()
1594+
k.name: (None if callable(v := getattr(value, k.name)) else v)
1595+
for k in dataclasses.fields(value)
15961596
})
15971597

15981598
if isinstance(value, range):

tests/units/test_state.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from reflex.constants.state import FIELD_MARKER
2929
from reflex.environment import environment
3030
from reflex.event import Event, EventHandler
31+
from reflex.istate.data import HeaderData, _FrozenDictStrStr
3132
from reflex.istate.manager import StateManager
3233
from reflex.istate.manager.disk import StateManagerDisk
3334
from reflex.istate.manager.memory import StateManagerMemory
@@ -925,7 +926,11 @@ def test_get_sid(test_state, router_data):
925926
assert test_state.router.session.session_id == "9fpxSzPb9aFMb4wFAAAH"
926927

927928

928-
def test_get_headers(test_state, router_data, router_data_headers):
929+
def test_get_headers(
930+
test_state: TestState,
931+
router_data: dict[str, str | dict],
932+
router_data_headers: dict[str, str],
933+
):
929934
"""Test getting client headers.
930935
931936
Args:
@@ -936,13 +941,10 @@ def test_get_headers(test_state, router_data, router_data_headers):
936941
print(router_data_headers)
937942
test_state.router = RouterData.from_router_data(router_data)
938943
print(test_state.router.headers)
939-
assert dataclasses.asdict(test_state.router.headers) == {
940-
format.to_snake_case(k): v for k, v in router_data_headers.items()
941-
} | {
942-
"raw_headers": {
943-
"_data": tuple(sorted((k, v) for k, v in router_data_headers.items()))
944-
}
945-
}
944+
assert test_state.router.headers == HeaderData(
945+
**{format.to_snake_case(k): v for k, v in router_data_headers.items()},
946+
raw_headers=_FrozenDictStrStr(**router_data_headers),
947+
)
946948

947949

948950
def test_get_client_ip(test_state, router_data):

0 commit comments

Comments
 (0)