@@ -305,8 +305,9 @@ def _get_backend_shutdown_handler(self):
305305
306306 async def _shutdown (* args , ** kwargs ) -> None :
307307 # ensure redis is closed before event loop
308- if self .app_instance is not None and isinstance (
309- self .app_instance ._state_manager , StateManagerRedis
308+ if (
309+ self .app_instance is not None
310+ and self .app_instance ._state_manager is not None
310311 ):
311312 with contextlib .suppress (ValueError ):
312313 await self .app_instance ._state_manager .close ()
@@ -358,6 +359,12 @@ async def _reset_backend_state_manager(self):
358359 Raises:
359360 RuntimeError: when the state manager cannot be reset
360361 """
362+ if (
363+ self .app_instance is not None
364+ and self .app_instance ._state_manager is not None
365+ ):
366+ with contextlib .suppress (RuntimeError ):
367+ await self .app_instance ._state_manager .close ()
361368 if (
362369 self .app_instance is not None
363370 and isinstance (
@@ -366,8 +373,6 @@ async def _reset_backend_state_manager(self):
366373 )
367374 and self .app_instance ._state is not None
368375 ):
369- with contextlib .suppress (RuntimeError ):
370- await self .app_instance ._state_manager .close ()
371376 self .app_instance ._state_manager = StateManagerRedis .create (
372377 state = self .app_instance ._state ,
373378 )
@@ -716,8 +721,7 @@ async def get_state(self, token: str) -> BaseState:
716721 try :
717722 return await self .state_manager .get_state (token )
718723 finally :
719- if isinstance (self .state_manager , StateManagerRedis ):
720- await self .state_manager .close ()
724+ await self .state_manager .close ()
721725
722726 async def set_state (self , token : str , ** kwargs ) -> None :
723727 """Set the state associated with the given token.
@@ -738,8 +742,7 @@ async def set_state(self, token: str, **kwargs) -> None:
738742 try :
739743 await self .state_manager .set_state (token , state )
740744 finally :
741- if isinstance (self .state_manager , StateManagerRedis ):
742- await self .state_manager .close ()
745+ await self .state_manager .close ()
743746
744747 @contextlib .asynccontextmanager
745748 async def modify_state (self , token : str ) -> AsyncIterator [BaseState ]:
@@ -769,9 +772,8 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
769772 async with self .app_instance .modify_state (token ) as state :
770773 yield state
771774 finally :
772- if isinstance (self .state_manager , StateManagerRedis ):
773- self .app_instance ._state_manager = app_state_manager
774- await self .state_manager .close ()
775+ self .app_instance ._state_manager = app_state_manager
776+ await self .state_manager .close ()
775777
776778 def token_manager (self ) -> TokenManager :
777779 """Get the token manager for the app instance.
0 commit comments