Skip to content

Commit a20f0bf

Browse files
authored
RedisStateManager.get_state: return the correct state class (#6001)
* RedisStateManager.get_state: return the correct state class Save the originally requested state class as `requested_state_cls` so a subsequent loop variable that was also called `state_cls` doesn't interfere with returning the correct state at the end of the function. * Update _potentially_dirty_states when adding explicit dependency Ensure that fetching a dependency state causes dependent states to also be fetched.
1 parent cb262b6 commit a20f0bf

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

reflex/istate/manager/redis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ async def get_state(
279279
token, state_path = _split_substate_key(token)
280280
if state_path:
281281
# Get the State class associated with the given path.
282-
state_cls = self.state.get_class_substate(state_path)
282+
requested_state_cls = self.state.get_class_substate(state_path)
283283
else:
284284
msg = f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
285285
raise RuntimeError(msg)
@@ -291,7 +291,7 @@ async def get_state(
291291

292292
# Determine which states from the tree need to be fetched.
293293
required_state_classes = sorted(
294-
self._get_required_state_classes(state_cls, subclasses=True)
294+
self._get_required_state_classes(requested_state_cls, subclasses=True)
295295
- {type(s) for s in flat_state_tree.values()},
296296
key=lambda x: x.get_full_name(),
297297
)
@@ -337,7 +337,7 @@ async def get_state(
337337
# the top-level state which should always be fetched or already cached.
338338
if top_level:
339339
return flat_state_tree[self.state.get_full_name()]
340-
return flat_state_tree[state_cls.get_full_name()]
340+
return flat_state_tree[requested_state_cls.get_full_name()]
341341

342342
@override
343343
async def set_state(

reflex/vars/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,12 +2488,18 @@ def add_dependency(self, objclass: type[BaseState], dep: Var):
24882488
var_name = all_var_data.field_name
24892489
if var_name:
24902490
self._static_deps.setdefault(state_name, set()).add(var_name)
2491-
objclass.get_root_state().get_class_substate(
2491+
target_state_class = objclass.get_root_state().get_class_substate(
24922492
state_name
2493-
)._var_dependencies.setdefault(var_name, set()).add((
2493+
)
2494+
target_state_class._var_dependencies.setdefault(
2495+
var_name, set()
2496+
).add((
24942497
objclass.get_full_name(),
24952498
self._name,
24962499
))
2500+
target_state_class._potentially_dirty_states.add(
2501+
objclass.get_full_name()
2502+
)
24972503
return
24982504
msg = (
24992505
"ComputedVar dependencies must be Var instances with a state and "

tests/units/test_state.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
)
5656
from reflex.utils.format import json_dumps
5757
from reflex.utils.token_manager import SocketRecord
58-
from reflex.vars.base import Var, computed_var
58+
from reflex.vars.base import Field, Var, computed_var, field
5959
from tests.units.mock_redis import mock_redis
6060

6161
from .states import GenState
@@ -306,12 +306,12 @@ def test_base_class_vars(test_state):
306306
fields = test_state.get_fields()
307307
cls = type(test_state)
308308

309-
for field in fields:
310-
if field.startswith("_") or field in cls.get_skip_vars():
309+
for field_name in fields:
310+
if field_name.startswith("_") or field_name in cls.get_skip_vars():
311311
continue
312-
prop = getattr(cls, field)
312+
prop = getattr(cls, field_name)
313313
assert isinstance(prop, Var)
314-
assert prop._js_expr.split(".")[-1] == field + FIELD_MARKER
314+
assert prop._js_expr.split(".")[-1] == field_name + FIELD_MARKER
315315

316316
assert cls.num1._var_type is int
317317
assert cls.num2._var_type is float
@@ -4304,6 +4304,8 @@ class OtherState(rx.State):
43044304
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
43054305
other_state = await state.get_state(OtherState)
43064306
assert comp.State is not None
4307+
# The state should have been pre-cached from the dependency.
4308+
assert comp.State.get_name() in state.substates
43074309
comp_state = await state.get_state(comp.State)
43084310
assert comp_state.dirty_vars == set()
43094311

@@ -4329,3 +4331,35 @@ class SecondCvState(CvMixin, rx.State):
43294331

43304332
assert first_cv is not second_cv
43314333
assert first_cv._static_deps is not second_cv._static_deps
4334+
4335+
4336+
@pytest.mark.asyncio
4337+
async def test_add_dependency_get_state_regression(mock_app: rx.App, token: str):
4338+
"""Ensure that a state class can be fetched separately when it's is explicit dep."""
4339+
4340+
class DataState(rx.State):
4341+
"""A state with a var."""
4342+
4343+
data: Field[list[int]] = field(default_factory=lambda: [1, 2, 3])
4344+
4345+
class StatsState(rx.State):
4346+
"""A state with a computed var depending on DataState."""
4347+
4348+
@rx.var(cache=True)
4349+
async def total(self) -> int:
4350+
data_state = await self.get_state(DataState)
4351+
return sum(data_state.data)
4352+
4353+
StatsState.computed_vars["total"].add_dependency(StatsState, DataState.data)
4354+
4355+
class OtherState(rx.State):
4356+
"""A state that gets DataState."""
4357+
4358+
@rx.event
4359+
async def fetch_data_state(self) -> None:
4360+
print(await self.get_state(DataState))
4361+
4362+
mock_app.state_manager.state = mock_app._state = rx.State
4363+
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
4364+
other_state = await state.get_state(OtherState)
4365+
await other_state.fetch_data_state() # Should not raise exception.

0 commit comments

Comments
 (0)