Skip to content

Commit 477348f

Browse files
authored
Remove token/sid associations when server is exiting (#5802)
* Remove token/sid associations when server is exiting This allows existing tokens to reconnect to redis after a hot or cold reload of the app. Otherwise, the old associations for the token remain in place and when the same client reconnects, it is given a new_token, since the requested token is already "taken" in redis. * test_connection_banner: assert that token/sid association removed on shutdown * Re-fetch the token_manager after restarting backend
1 parent 4468e14 commit 477348f

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

reflex/app_mixins/lifespan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ async def _run_lifespan_tasks(self, app: Starlette):
6060
for task in running_tasks:
6161
console.debug(f"Canceling lifespan task: {task}")
6262
task.cancel(msg="lifespan_cleanup")
63+
# Disassociate sid / token pairings so they can be reconnected properly.
64+
try:
65+
event_namespace = self.event_namespace # pyright: ignore[reportAttributeAccessIssue]
66+
except AttributeError:
67+
pass
68+
else:
69+
try:
70+
if event_namespace:
71+
await event_namespace._token_manager.disconnect_all()
72+
except Exception as e:
73+
console.error(f"Error during lifespan cleanup: {e}")
6374

6475
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
6576
"""Register a task to run during the lifespan of the app.

reflex/testing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
)
4848
from reflex.utils import console, js_runtimes
4949
from reflex.utils.export import export
50+
from reflex.utils.token_manager import TokenManager
5051
from reflex.utils.types import ASGIApp
5152

5253
try:
@@ -774,6 +775,19 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
774775
self.app_instance._state_manager = app_state_manager
775776
await self.state_manager.close()
776777

778+
def token_manager(self) -> TokenManager:
779+
"""Get the token manager for the app instance.
780+
781+
Returns:
782+
The current token_manager attached to the app's EventNamespace.
783+
"""
784+
assert self.app_instance is not None
785+
app_event_namespace = self.app_instance.event_namespace
786+
assert app_event_namespace is not None
787+
app_token_manager = app_event_namespace._token_manager
788+
assert app_token_manager is not None
789+
return app_token_manager
790+
777791
def poll_for_content(
778792
self,
779793
element: WebElement,

reflex/utils/token_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ def create(cls) -> TokenManager:
6666

6767
return LocalTokenManager()
6868

69+
async def disconnect_all(self):
70+
"""Disconnect all tracked tokens when the server is going down."""
71+
token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items())
72+
token_sid_pairs.update(
73+
((token, sid) for sid, token in self.sid_to_token.items())
74+
)
75+
# Perform the disconnection logic here
76+
for token, sid in token_sid_pairs:
77+
await self.disconnect_token(token, sid)
78+
6979

7080
class LocalTokenManager(TokenManager):
7181
"""Token manager using local in-memory dictionaries (single worker)."""

tests/integration/test_connection_banner.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from reflex import constants
1010
from reflex.environment import environment
11+
from reflex.istate.manager import StateManagerRedis
1112
from reflex.testing import AppHarness, WebDriver
13+
from reflex.utils.token_manager import RedisTokenManager
1214

1315
from .utils import SessionStorage
1416

@@ -127,17 +129,21 @@ def has_cloud_banner(driver: WebDriver) -> bool:
127129
return True
128130

129131

130-
def _assert_token(connection_banner, driver):
132+
def _assert_token(connection_banner, driver) -> str:
131133
"""Poll for backend to be up.
132134
133135
Args:
134136
connection_banner: AppHarness instance.
135137
driver: Selenium webdriver instance.
138+
139+
Returns:
140+
The token if found, raises an assertion error otherwise.
136141
"""
137142
ss = SessionStorage(driver)
138143
assert connection_banner._poll_for(lambda: ss.get("token") is not None), (
139144
"token not found"
140145
)
146+
return ss.get("token")
141147

142148

143149
@pytest.mark.asyncio
@@ -151,9 +157,22 @@ async def test_connection_banner(connection_banner: AppHarness):
151157
assert connection_banner.backend is not None
152158
driver = connection_banner.frontend()
153159

154-
_assert_token(connection_banner, driver)
160+
token = _assert_token(connection_banner, driver)
155161
AppHarness.expect(lambda: not has_error_modal(driver))
156162

163+
# Check that the token association was established.
164+
app_token_manager = connection_banner.token_manager()
165+
assert token in app_token_manager.token_to_sid
166+
sid_before = app_token_manager.token_to_sid[token]
167+
if isinstance(connection_banner.state_manager, StateManagerRedis):
168+
assert isinstance(app_token_manager, RedisTokenManager)
169+
assert (
170+
await connection_banner.state_manager.redis.get(
171+
app_token_manager._get_redis_key(token)
172+
)
173+
== b"1"
174+
)
175+
157176
delay_button = driver.find_element(By.ID, "delay")
158177
increment_button = driver.find_element(By.ID, "increment")
159178
counter_element = driver.find_element(By.ID, "counter")
@@ -176,6 +195,17 @@ async def test_connection_banner(connection_banner: AppHarness):
176195
# Error modal should now be displayed
177196
AppHarness.expect(lambda: has_error_modal(driver))
178197

198+
# The token association should have been removed when the server exited.
199+
assert token not in app_token_manager.token_to_sid
200+
if isinstance(connection_banner.state_manager, StateManagerRedis):
201+
assert isinstance(app_token_manager, RedisTokenManager)
202+
assert (
203+
await connection_banner.state_manager.redis.get(
204+
app_token_manager._get_redis_key(token)
205+
)
206+
is None
207+
)
208+
179209
# Increment the counter with backend down
180210
increment_button.click()
181211
assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
@@ -189,6 +219,20 @@ async def test_connection_banner(connection_banner: AppHarness):
189219
# Banner should be gone now
190220
AppHarness.expect(lambda: not has_error_modal(driver))
191221

222+
# After reconnecting, the token association should be re-established.
223+
app_token_manager = connection_banner.token_manager()
224+
if isinstance(connection_banner.state_manager, StateManagerRedis):
225+
assert isinstance(app_token_manager, RedisTokenManager)
226+
assert (
227+
await connection_banner.state_manager.redis.get(
228+
app_token_manager._get_redis_key(token)
229+
)
230+
== b"1"
231+
)
232+
# Make sure the new connection has a different websocket sid.
233+
sid_after = app_token_manager.token_to_sid[token]
234+
assert sid_before != sid_after
235+
192236
# Count should have incremented after coming back up
193237
assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"
194238

0 commit comments

Comments
 (0)