Skip to content

Commit 21f7629

Browse files
authored
Automatic websocket reconnect and reload handling (#5805)
* emit_update: take `token` instead of `sid`. This allows the app to be more resilient in the face of websocket reconnects. The event is processed against a token, so there's no reason to maintain websocket affinity for event processing. Whenever the update is ready to send, it will be sent to the websocket/sid currently associated with that token. * Automatic websocket reconnect and reload handling * ensureSocketConnected is called when adding events or pumping the queue to trigger an automatic reconnection to the backend * when a "reload" event is encountered, trigger a re-hydrate and wait until ALL on_load have finished processing and `is_hydrated` is True before requeueing the event that caused the "reload" * Update mock token_to_sid mapping for test * Add disconnect/reconnect test to test_background_task * Remove non-background disconnect/reconnect test. It doesn't really work, because the frontend will only process one non-background event at a time, so the disconnect ends up occurring after the event handler is already done.
1 parent 1d9fee6 commit 21f7629

File tree

5 files changed

+123
-36
lines changed

5 files changed

+123
-36
lines changed

reflex/.templates/web/utils/state.js

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ export const connect = async (
531531
) => {
532532
// Get backend URL object from the endpoint.
533533
const endpoint = getBackendURL(EVENTURL);
534+
const on_hydrated_queue = [];
534535

535536
// Create the socket.
536537
socket.current = io(endpoint.href, {
@@ -552,7 +553,17 @@ export const connect = async (
552553

553554
function checkVisibility() {
554555
if (document.visibilityState === "visible") {
555-
if (!socket.current.connected) {
556+
if (!socket.current) {
557+
connect(
558+
socket,
559+
dispatch,
560+
transports,
561+
setConnectErrors,
562+
client_storage,
563+
navigate,
564+
params,
565+
);
566+
} else if (!socket.current.connected) {
556567
console.log("Socket is disconnected, attempting to reconnect ");
557568
socket.current.connect();
558569
} else {
@@ -593,6 +604,7 @@ export const connect = async (
593604

594605
// When the socket disconnects reset the event_processing flag
595606
socket.current.on("disconnect", () => {
607+
socket.current = null; // allow reconnect to occur automatically
596608
event_processing = false;
597609
window.removeEventListener("unload", disconnectTrigger);
598610
window.removeEventListener("beforeunload", disconnectTrigger);
@@ -603,6 +615,14 @@ export const connect = async (
603615
socket.current.on("event", async (update) => {
604616
for (const substate in update.delta) {
605617
dispatch[substate](update.delta[substate]);
618+
// handle events waiting for `is_hydrated`
619+
if (
620+
substate === state_name &&
621+
update.delta[substate]?.is_hydrated_rx_state_
622+
) {
623+
queueEvents(on_hydrated_queue, socket, false, navigate, params);
624+
on_hydrated_queue.length = 0;
625+
}
606626
}
607627
applyClientStorageDelta(client_storage, update.delta);
608628
event_processing = !update.final;
@@ -612,7 +632,8 @@ export const connect = async (
612632
});
613633
socket.current.on("reload", async (event) => {
614634
event_processing = false;
615-
queueEvents([...initialEvents(), event], socket, true, navigate, params);
635+
on_hydrated_queue.push(event);
636+
queueEvents(initialEvents(), socket, true, navigate, params);
616637
});
617638
socket.current.on("new_token", async (new_token) => {
618639
token = new_token;
@@ -774,10 +795,32 @@ export const useEventLoop = (
774795
}
775796
}, [paramsR]);
776797

798+
const ensureSocketConnected = useCallback(async () => {
799+
// only use websockets if state is present and backend is not disabled (reflex cloud).
800+
if (
801+
Object.keys(initialState).length > 1 &&
802+
!isBackendDisabled() &&
803+
!socket.current
804+
) {
805+
// Initialize the websocket connection.
806+
await connect(
807+
socket,
808+
dispatch,
809+
["websocket"],
810+
setConnectErrors,
811+
client_storage,
812+
navigate,
813+
() => params.current,
814+
);
815+
}
816+
}, [socket, dispatch, setConnectErrors, client_storage, navigate, params]);
817+
777818
// Function to add new events to the event queue.
778819
const addEvents = useCallback((events, args, event_actions) => {
779820
const _events = events.filter((e) => e !== undefined && e !== null);
780821

822+
ensureSocketConnected();
823+
781824
if (!(args instanceof Array)) {
782825
args = [args];
783826
}
@@ -870,21 +913,8 @@ export const useEventLoop = (
870913

871914
// Handle socket connect/disconnect.
872915
useEffect(() => {
873-
// only use websockets if state is present and backend is not disabled (reflex cloud).
874-
if (Object.keys(initialState).length > 1 && !isBackendDisabled()) {
875-
// Initialize the websocket connection.
876-
if (!socket.current) {
877-
connect(
878-
socket,
879-
dispatch,
880-
["websocket"],
881-
setConnectErrors,
882-
client_storage,
883-
navigate,
884-
() => params.current,
885-
);
886-
}
887-
}
916+
// Initialize the websocket connection.
917+
ensureSocketConnected();
888918

889919
// Cleanup function.
890920
return () => {
@@ -903,6 +933,7 @@ export const useEventLoop = (
903933
(async () => {
904934
// Process all outstanding events.
905935
while (event_queue.length > 0 && !event_processing) {
936+
await ensureSocketConnected();
906937
await processEvent(socket.current, navigate, () => params.current);
907938
}
908939
})();

reflex/app.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
State,
9898
StateManager,
9999
StateUpdate,
100+
_split_substate_key,
100101
_substate_key,
101102
all_base_state_classes,
102103
code_uses_state_contexts,
@@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
15591560
state._clean()
15601561
await self.event_namespace.emit_update(
15611562
update=StateUpdate(delta=delta),
1562-
sid=state.router.session.session_id,
1563+
token=token,
15631564
)
15641565

15651566
def _process_background(
@@ -1599,7 +1600,7 @@ async def _coro():
15991600
# Send the update to the client.
16001601
await self.event_namespace.emit_update(
16011602
update=update,
1602-
sid=state.router.session.session_id,
1603+
token=event.token,
16031604
)
16041605

16051606
task = asyncio.create_task(
@@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str):
20612062
and console.error(f"Token cleanup error: {t.exception()}")
20622063
)
20632064

2064-
async def emit_update(self, update: StateUpdate, sid: str) -> None:
2065+
async def emit_update(self, update: StateUpdate, token: str) -> None:
20652066
"""Emit an update to the client.
20662067
20672068
Args:
20682069
update: The state update to send.
2069-
sid: The Socket.IO session id.
2070+
token: The client token (tab) associated with the event.
20702071
"""
2071-
if not sid:
2072+
client_token, _ = _split_substate_key(token)
2073+
sid = self.token_to_sid.get(client_token)
2074+
if sid is None:
20722075
# If the sid is None, we are not connected to a client. Prevent sending
20732076
# updates to all clients.
2074-
return
2075-
token = self.sid_to_token.get(sid)
2076-
if token is None:
2077-
console.warn(f"Attempting to send delta to disconnected websocket {sid}")
2077+
console.warn(f"Attempting to send delta to disconnected client {token!r}")
20782078
return
20792079
# Creating a task prevents the update from being blocked behind other coroutines.
20802080
await asyncio.create_task(
@@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any):
21652165
# Process the events.
21662166
async for update in updates_gen:
21672167
# Emit the update from processing the event.
2168-
await self.emit_update(update=update, sid=sid)
2168+
await self.emit_update(update=update, token=event.token)
21692169

21702170
async def on_ping(self, sid: str):
21712171
"""Event for testing the API endpoint.

reflex/istate/proxy.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,15 @@ def __init__(
7171
state_instance: The state instance to proxy.
7272
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
7373
"""
74+
from reflex.state import _substate_key
75+
7476
super().__init__(state_instance)
75-
# compile is not relevant to backend logic
7677
self._self_app = prerequisites.get_and_validate_app().app
7778
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
79+
self._self_substate_token = _substate_key(
80+
state_instance.router.session.client_token,
81+
self._self_substate_path,
82+
)
7883
self._self_actx = None
7984
self._self_mutable = False
8085
self._self_actx_lock = asyncio.Lock()
@@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy:
127132
msg = "The state is already mutable. Do not nest `async with self` blocks."
128133
raise ImmutableStateError(msg)
129134

130-
from reflex.state import _substate_key
131-
132135
await self._self_actx_lock.acquire()
133136
self._self_actx_lock_holder = current_task
134-
self._self_actx = self._self_app.modify_state(
135-
token=_substate_key(
136-
self.__wrapped__.router.session.client_token,
137-
self._self_substate_path,
138-
)
139-
)
137+
self._self_actx = self._self_app.modify_state(token=self._self_substate_token)
140138
mutable_state = await self._self_actx.__aenter__()
141139
super().__setattr__(
142140
"__wrapped__", mutable_state.get_substate(self._self_substate_path)

tests/integration/test_background_task.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,15 @@ async def yield_in_async_with_self(self):
109109
yield
110110
self.counter += 1
111111

112+
@rx.event(background=True)
113+
async def disconnect_reconnect_background(self):
114+
async with self:
115+
self.counter += 1
116+
yield rx.call_script("socket.disconnect()")
117+
await asyncio.sleep(0.5)
118+
async with self:
119+
self.counter += 1
120+
112121
class OtherState(rx.State):
113122
@rx.event(background=True)
114123
async def get_other_state(self):
@@ -134,6 +143,9 @@ def index() -> rx.Component:
134143
rx.input(
135144
id="token", value=State.router.session.client_token, is_read_only=True
136145
),
146+
rx.input(
147+
id="sid", value=State.router.session.session_id, is_read_only=True
148+
),
137149
rx.hstack(
138150
rx.heading(State.counter, id="counter"),
139151
rx.text(State.counter_async_cv, size="1", id="counter-async-cv"),
@@ -185,6 +197,11 @@ def index() -> rx.Component:
185197
on_click=State.yield_in_async_with_self,
186198
id="yield-in-async-with-self",
187199
),
200+
rx.button(
201+
"Disconnect / Reconnect Background",
202+
on_click=State.disconnect_reconnect_background,
203+
id="disconnect-reconnect-background",
204+
),
188205
rx.button("Reset", on_click=State.reset_counter, id="reset"),
189206
)
190207

@@ -395,3 +412,42 @@ def test_yield_in_async_with_self(
395412

396413
yield_in_async_with_self_button.click()
397414
AppHarness.expect(lambda: counter.text == "2", timeout=5)
415+
416+
417+
@pytest.mark.parametrize(
418+
"button_id",
419+
[
420+
"disconnect-reconnect-background",
421+
],
422+
)
423+
def test_disconnect_reconnect(
424+
background_task: AppHarness,
425+
driver: WebDriver,
426+
token: str,
427+
button_id: str,
428+
):
429+
"""Test that disconnecting and reconnecting works as expected.
430+
431+
Args:
432+
background_task: harness for BackgroundTask app.
433+
driver: WebDriver instance.
434+
token: The token for the connected client.
435+
button_id: The ID of the button to click.
436+
"""
437+
counter = driver.find_element(By.ID, "counter")
438+
button = driver.find_element(By.ID, button_id)
439+
increment_button = driver.find_element(By.ID, "increment")
440+
sid_input = driver.find_element(By.ID, "sid")
441+
sid = background_task.poll_for_value(sid_input, timeout=5)
442+
assert sid is not None
443+
444+
AppHarness.expect(lambda: counter.text == "0", timeout=5)
445+
button.click()
446+
AppHarness.expect(lambda: counter.text == "1", timeout=5)
447+
increment_button.click()
448+
# should get a new sid after the reconnect
449+
assert (
450+
background_task.poll_for_value(sid_input, timeout=5, exp_not_equal=sid) != sid
451+
)
452+
# Final update should come through on the new websocket connection
453+
AppHarness.expect(lambda: counter.text == "3", timeout=5)

tests/units/test_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,7 @@ async def test_state_proxy(
20052005
namespace = mock_app.event_namespace
20062006
assert namespace is not None
20072007
namespace.sid_to_token[router_data.session.session_id] = token
2008+
namespace.token_to_sid[token] = router_data.session.session_id
20082009
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
20092010
mock_app.state_manager.states[parent_state.router.session.client_token] = (
20102011
parent_state
@@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
22142215
namespace = mock_app.event_namespace
22152216
assert namespace is not None
22162217
namespace.sid_to_token[sid] = token
2218+
namespace.token_to_sid[token] = sid
22172219
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
22182220
async for update in rx.app.process(
22192221
mock_app,

0 commit comments

Comments
 (0)