Skip to content

Commit 77594bd

Browse files
committed
[HOS-333] Send a "reload" message to the frontend after state expiry (#4442)
* Unit test updates * test_client_storage: simulate backend state expiry * [HOS-333] Send a "reload" message to the frontend after state expiry 1. a state instance expires on the backing store 2. frontend attempts to process an event against the expired token and gets a fresh instance of the state without router_data set 3. backend sends a "reload" message on the websocket containing the event and immediately stops processing 4. in response to the "reload" message, frontend sends [hydrate, update client storage, on_load, <previous_event>] This allows the frontend and backend to re-syncronize on the state of the app before continuing to process regular events. If the event in (2) is a special hydrate event, then it is processed normally by the middleware and the "reload" logic is skipped since this indicates an initial load or a browser refresh. * unit tests working with redis
1 parent e4ccba7 commit 77594bd

File tree

6 files changed

+150
-6
lines changed

6 files changed

+150
-6
lines changed

reflex/.templates/web/utils/state.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,10 @@ export const connect = async (
454454
queueEvents(update.events, socket);
455455
}
456456
});
457+
socket.current.on("reload", async (event) => {
458+
event_processing = false;
459+
queueEvents([...initialEvents(), JSON5.parse(event)], socket);
460+
})
457461

458462
document.addEventListener("visibilitychange", checkVisibility);
459463
};

reflex/app.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
EventSpec,
7474
EventType,
7575
IndividualEventType,
76+
get_hydrate_event,
7677
window_alert,
7778
)
7879
from reflex.model import Model, get_db_status
@@ -1259,6 +1260,21 @@ async def process(
12591260
)
12601261
# Get the state for the session exclusively.
12611262
async with app.state_manager.modify_state(event.substate_token) as state:
1263+
# When this is a brand new instance of the state, signal the
1264+
# frontend to reload before processing it.
1265+
if (
1266+
not state.router_data
1267+
and event.name != get_hydrate_event(state)
1268+
and app.event_namespace is not None
1269+
):
1270+
await asyncio.create_task(
1271+
app.event_namespace.emit(
1272+
"reload",
1273+
data=format.json_dumps(event),
1274+
to=sid,
1275+
)
1276+
)
1277+
return
12621278
# re-assign only when the value is different
12631279
if state.router_data != router_data:
12641280
# assignment will recurse into substates and force recalculation of

reflex/state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,6 +1959,9 @@ def _update_was_touched(self):
19591959
if var in self.base_vars or var in self._backend_vars:
19601960
self._was_touched = True
19611961
break
1962+
if var == constants.ROUTER_DATA and self.parent_state is None:
1963+
self._was_touched = True
1964+
break
19621965

19631966
def _get_was_touched(self) -> bool:
19641967
"""Check current dirty_vars and flag to determine if state instance was modified.

tests/integration/test_client_storage.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from selenium.webdriver.common.by import By
1111
from selenium.webdriver.remote.webdriver import WebDriver
1212

13+
from reflex.state import (
14+
State,
15+
StateManagerDisk,
16+
StateManagerMemory,
17+
StateManagerRedis,
18+
_substate_key,
19+
)
1320
from reflex.testing import AppHarness
1421

1522
from . import utils
@@ -74,7 +81,7 @@ def index():
7481
return rx.fragment(
7582
rx.input(
7683
value=ClientSideState.router.session.client_token,
77-
is_read_only=True,
84+
read_only=True,
7885
id="token",
7986
),
8087
rx.input(
@@ -604,6 +611,110 @@ def set_sub_sub(var: str, value: str):
604611
assert s2.text == "s2 value"
605612
assert s3.text == "s3 value"
606613

614+
# Simulate state expiration
615+
if isinstance(client_side.state_manager, StateManagerRedis):
616+
await client_side.state_manager.redis.delete(
617+
_substate_key(token, State.get_full_name())
618+
)
619+
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
620+
await client_side.state_manager.redis.delete(
621+
_substate_key(token, sub_state_name)
622+
)
623+
await client_side.state_manager.redis.delete(
624+
_substate_key(token, sub_sub_state_name)
625+
)
626+
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
627+
del client_side.state_manager.states[token]
628+
if isinstance(client_side.state_manager, StateManagerDisk):
629+
client_side.state_manager.token_expiration = 0
630+
client_side.state_manager._purge_expired_states()
631+
632+
# Ensure the state is gone (not hydrated)
633+
async def poll_for_not_hydrated():
634+
state = await client_side.get_state(_substate_key(token or "", state_name))
635+
return not state.is_hydrated
636+
637+
assert await AppHarness._poll_for_async(poll_for_not_hydrated)
638+
639+
# Trigger event to get a new instance of the state since the old was expired.
640+
state_var_input = driver.find_element(By.ID, "state_var")
641+
state_var_input.send_keys("re-triggering")
642+
643+
# get new references to all cookie and local storage elements (again)
644+
c1 = driver.find_element(By.ID, "c1")
645+
c2 = driver.find_element(By.ID, "c2")
646+
c3 = driver.find_element(By.ID, "c3")
647+
c4 = driver.find_element(By.ID, "c4")
648+
c5 = driver.find_element(By.ID, "c5")
649+
c6 = driver.find_element(By.ID, "c6")
650+
c7 = driver.find_element(By.ID, "c7")
651+
l1 = driver.find_element(By.ID, "l1")
652+
l2 = driver.find_element(By.ID, "l2")
653+
l3 = driver.find_element(By.ID, "l3")
654+
l4 = driver.find_element(By.ID, "l4")
655+
s1 = driver.find_element(By.ID, "s1")
656+
s2 = driver.find_element(By.ID, "s2")
657+
s3 = driver.find_element(By.ID, "s3")
658+
c1s = driver.find_element(By.ID, "c1s")
659+
l1s = driver.find_element(By.ID, "l1s")
660+
s1s = driver.find_element(By.ID, "s1s")
661+
662+
assert c1.text == "c1 value"
663+
assert c2.text == "c2 value"
664+
assert c3.text == "" # temporary cookie expired after reset state!
665+
assert c4.text == "c4 value"
666+
assert c5.text == "c5 value"
667+
assert c6.text == "c6 value"
668+
assert c7.text == "c7 value"
669+
assert l1.text == "l1 value"
670+
assert l2.text == "l2 value"
671+
assert l3.text == "l3 value"
672+
assert l4.text == "l4 value"
673+
assert s1.text == "s1 value"
674+
assert s2.text == "s2 value"
675+
assert s3.text == "s3 value"
676+
assert c1s.text == "c1s value"
677+
assert l1s.text == "l1s value"
678+
assert s1s.text == "s1s value"
679+
680+
# Get the backend state and ensure the values are still set
681+
async def get_sub_state():
682+
root_state = await client_side.get_state(
683+
_substate_key(token or "", sub_state_name)
684+
)
685+
state = root_state.substates[client_side.get_state_name("_client_side_state")]
686+
sub_state = state.substates[
687+
client_side.get_state_name("_client_side_sub_state")
688+
]
689+
return sub_state
690+
691+
async def poll_for_c1_set():
692+
sub_state = await get_sub_state()
693+
return sub_state.c1 == "c1 value"
694+
695+
assert await AppHarness._poll_for_async(poll_for_c1_set)
696+
sub_state = await get_sub_state()
697+
assert sub_state.c1 == "c1 value"
698+
assert sub_state.c2 == "c2 value"
699+
assert sub_state.c3 == ""
700+
assert sub_state.c4 == "c4 value"
701+
assert sub_state.c5 == "c5 value"
702+
assert sub_state.c6 == "c6 value"
703+
assert sub_state.c7 == "c7 value"
704+
assert sub_state.l1 == "l1 value"
705+
assert sub_state.l2 == "l2 value"
706+
assert sub_state.l3 == "l3 value"
707+
assert sub_state.l4 == "l4 value"
708+
assert sub_state.s1 == "s1 value"
709+
assert sub_state.s2 == "s2 value"
710+
assert sub_state.s3 == "s3 value"
711+
sub_sub_state = sub_state.substates[
712+
client_side.get_state_name("_client_side_sub_sub_state")
713+
]
714+
assert sub_sub_state.c1s == "c1s value"
715+
assert sub_sub_state.l1s == "l1s value"
716+
assert sub_sub_state.s1s == "s1s value"
717+
607718
# clear the cookie jar and local storage, ensure state reset to default
608719
driver.delete_all_cookies()
609720
local_storage.clear()

tests/units/test_app.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
10071007
substate_token = _substate_key(token, DynamicState)
10081008
sid = "mock_sid"
10091009
client_ip = "127.0.0.1"
1010-
state = await app.state_manager.get_state(substate_token)
1011-
assert state.dynamic == ""
1010+
async with app.state_manager.modify_state(substate_token) as state:
1011+
state.router_data = {"simulate": "hydrated"}
1012+
assert state.dynamic == ""
10121013
exp_vals = ["foo", "foobar", "baz"]
10131014

10141015
def _event(name, val, **kwargs):
@@ -1180,13 +1181,16 @@ async def test_process_events(mocker, token: str):
11801181
"ip": "127.0.0.1",
11811182
}
11821183
app = App(state=GenState)
1184+
11831185
mocker.patch.object(app, "_postprocess", AsyncMock())
11841186
event = Event(
11851187
token=token,
11861188
name=f"{GenState.get_name()}.go",
11871189
payload={"c": 5},
11881190
router_data=router_data,
11891191
)
1192+
async with app.state_manager.modify_state(event.substate_token) as state:
1193+
state.router_data = {"simulate": "hydrated"}
11901194

11911195
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
11921196
pass

tests/units/test_state.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,10 @@ class BackgroundTaskState(BaseState):
19821982
order: List[str] = []
19831983
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
19841984

1985+
def __init__(self, **kwargs): # noqa: D107
1986+
super().__init__(**kwargs)
1987+
self.router_data = {"simulate": "hydrate"}
1988+
19851989
@rx.var
19861990
def computed_order(self) -> List[str]:
19871991
"""Get the order as a computed var.
@@ -2732,7 +2736,7 @@ class BaseFieldSetterState(BaseState):
27322736
assert "c2" in bfss.dirty_vars
27332737

27342738

2735-
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
2739+
def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> Dict[str, Any]:
27362740
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
27372741
27382742
Args:
@@ -2811,7 +2815,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
28112815
app = app_module_mock.app = App(
28122816
state=State, load_events={"index": [test_state.test_handler]}
28132817
)
2814-
state = State()
2818+
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
2819+
state.router_data = {"simulate": "hydrate"}
28152820

28162821
updates = []
28172822
async for update in rx.app.process(
@@ -2858,7 +2863,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
28582863
state=State,
28592864
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
28602865
)
2861-
state = State()
2866+
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
2867+
state.router_data = {"simulate": "hydrate"}
28622868

28632869
updates = []
28642870
async for update in rx.app.process(

0 commit comments

Comments
 (0)