Skip to content

Commit bb1d8ec

Browse files
committed
AppHarness: call state_manager.close() for all state managers
1 parent 53c5cc0 commit bb1d8ec

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

reflex/testing.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,9 @@ def _get_backend_shutdown_handler(self):
305305

306306
async def _shutdown(*args, **kwargs) -> None:
307307
# ensure redis is closed before event loop
308-
if self.app_instance is not None and isinstance(
309-
self.app_instance._state_manager, StateManagerRedis
308+
if (
309+
self.app_instance is not None
310+
and self.app_instance._state_manager is not None
310311
):
311312
with contextlib.suppress(ValueError):
312313
await self.app_instance._state_manager.close()
@@ -358,6 +359,12 @@ async def _reset_backend_state_manager(self):
358359
Raises:
359360
RuntimeError: when the state manager cannot be reset
360361
"""
362+
if (
363+
self.app_instance is not None
364+
and self.app_instance._state_manager is not None
365+
):
366+
with contextlib.suppress(RuntimeError):
367+
await self.app_instance._state_manager.close()
361368
if (
362369
self.app_instance is not None
363370
and isinstance(
@@ -366,8 +373,6 @@ async def _reset_backend_state_manager(self):
366373
)
367374
and self.app_instance._state is not None
368375
):
369-
with contextlib.suppress(RuntimeError):
370-
await self.app_instance._state_manager.close()
371376
self.app_instance._state_manager = StateManagerRedis.create(
372377
state=self.app_instance._state,
373378
)
@@ -716,8 +721,7 @@ async def get_state(self, token: str) -> BaseState:
716721
try:
717722
return await self.state_manager.get_state(token)
718723
finally:
719-
if isinstance(self.state_manager, StateManagerRedis):
720-
await self.state_manager.close()
724+
await self.state_manager.close()
721725

722726
async def set_state(self, token: str, **kwargs) -> None:
723727
"""Set the state associated with the given token.
@@ -738,8 +742,7 @@ async def set_state(self, token: str, **kwargs) -> None:
738742
try:
739743
await self.state_manager.set_state(token, state)
740744
finally:
741-
if isinstance(self.state_manager, StateManagerRedis):
742-
await self.state_manager.close()
745+
await self.state_manager.close()
743746

744747
@contextlib.asynccontextmanager
745748
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -769,9 +772,8 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
769772
async with self.app_instance.modify_state(token) as state:
770773
yield state
771774
finally:
772-
if isinstance(self.state_manager, StateManagerRedis):
773-
self.app_instance._state_manager = app_state_manager
774-
await self.state_manager.close()
775+
self.app_instance._state_manager = app_state_manager
776+
await self.state_manager.close()
775777

776778
def token_manager(self) -> TokenManager:
777779
"""Get the token manager for the app instance.

0 commit comments

Comments
 (0)