88
99from reflex import constants
1010from reflex .environment import environment
11+ from reflex .istate .manager import StateManagerRedis
1112from reflex .testing import AppHarness , WebDriver
13+ from reflex .utils .token_manager import RedisTokenManager
1214
1315from .utils import SessionStorage
1416
@@ -127,17 +129,21 @@ def has_cloud_banner(driver: WebDriver) -> bool:
127129 return True
128130
129131
130- def _assert_token (connection_banner , driver ):
132+ def _assert_token (connection_banner , driver ) -> str :
131133 """Poll for backend to be up.
132134
133135 Args:
134136 connection_banner: AppHarness instance.
135137 driver: Selenium webdriver instance.
138+
139+ Returns:
140+ The token if found, raises an assertion error otherwise.
136141 """
137142 ss = SessionStorage (driver )
138143 assert connection_banner ._poll_for (lambda : ss .get ("token" ) is not None ), (
139144 "token not found"
140145 )
146+ return ss .get ("token" )
141147
142148
143149@pytest .mark .asyncio
@@ -151,9 +157,25 @@ async def test_connection_banner(connection_banner: AppHarness):
151157 assert connection_banner .backend is not None
152158 driver = connection_banner .frontend ()
153159
154- _assert_token (connection_banner , driver )
160+ token = _assert_token (connection_banner , driver )
155161 AppHarness .expect (lambda : not has_error_modal (driver ))
156162
163+ # Check that the token association was established.
164+ app_event_namespace = connection_banner .app_instance .event_namespace
165+ assert app_event_namespace is not None
166+ app_token_manager = app_event_namespace ._token_manager
167+ assert app_token_manager is not None
168+ assert token in app_token_manager .token_to_sid
169+ sid_before = app_token_manager .token_to_sid [token ]
170+ if isinstance (connection_banner .state_manager , StateManagerRedis ):
171+ assert isinstance (app_token_manager , RedisTokenManager )
172+ assert (
173+ await connection_banner .state_manager .redis .get (
174+ app_token_manager ._get_redis_key (token )
175+ )
176+ == b"1"
177+ )
178+
157179 delay_button = driver .find_element (By .ID , "delay" )
158180 increment_button = driver .find_element (By .ID , "increment" )
159181 counter_element = driver .find_element (By .ID , "counter" )
@@ -176,6 +198,17 @@ async def test_connection_banner(connection_banner: AppHarness):
176198 # Error modal should now be displayed
177199 AppHarness .expect (lambda : has_error_modal (driver ))
178200
201+ # The token association should have been removed when the server exited.
202+ assert token not in app_token_manager .token_to_sid
203+ if isinstance (connection_banner .state_manager , StateManagerRedis ):
204+ assert isinstance (app_token_manager , RedisTokenManager )
205+ assert (
206+ await connection_banner .state_manager .redis .get (
207+ app_token_manager ._get_redis_key (token )
208+ )
209+ is None
210+ )
211+
179212 # Increment the counter with backend down
180213 increment_button .click ()
181214 assert connection_banner .poll_for_value (counter_element , exp_not_equal = "0" ) == "1"
@@ -189,6 +222,19 @@ async def test_connection_banner(connection_banner: AppHarness):
189222 # Banner should be gone now
190223 AppHarness .expect (lambda : not has_error_modal (driver ))
191224
225+ # After reconnecting, the token association should be re-established.
226+ if isinstance (connection_banner .state_manager , StateManagerRedis ):
227+ assert isinstance (app_token_manager , RedisTokenManager )
228+ assert (
229+ await connection_banner .state_manager .redis .get (
230+ app_token_manager ._get_redis_key (token )
231+ )
232+ == b"1"
233+ )
234+ # Make sure the new connection has a different websocket sid.
235+ sid_after = app_token_manager .token_to_sid [token ]
236+ assert sid_before != sid_after
237+
192238 # Count should have incremented after coming back up
193239 assert connection_banner .poll_for_value (counter_element , exp_not_equal = "1" ) == "2"
194240
0 commit comments