@@ -470,7 +470,6 @@ async def close(self) -> None:
470470 # already closing
471471 return
472472 self ._state = SessionState .CLOSING
473- self ._reset_session_close_countdown ()
474473 await self ._task_manager .cancel_all_tasks ()
475474
476475 # TODO: unexpected_close should close stream differently here to
@@ -535,19 +534,21 @@ def _start_close_session_checker(self) -> None:
535534 )
536535
537536 def _start_heartbeat (self ) -> None :
538- async def do_close_websocket () -> None :
537+ async def close_websocket () -> None :
539538 logger .debug (
540539 "do_close called, _state=%r, _ws_unwrapped=%r" ,
541540 self ._state ,
542541 self ._ws_unwrapped ,
543542 )
544543 if self ._ws_unwrapped :
545544 self ._task_manager .create_task (self ._ws_unwrapped .close ())
546- if self ._retry_connection_callback :
547- self ._task_manager .create_task (self ._retry_connection_callback ())
548545 self ._ws_unwrapped = None
546+
547+ if self ._retry_connection_callback :
548+ self ._task_manager .create_task (self ._retry_connection_callback ())
549549 else :
550550 self ._state = SessionState .CLOSING
551+
551552 await self ._begin_close_session_countdown ()
552553
553554 def increment_and_get_heartbeat_misses () -> int :
@@ -561,7 +562,7 @@ def increment_and_get_heartbeat_misses() -> int:
561562 self ._transport_options .heartbeats_until_dead ,
562563 lambda : self ._state ,
563564 lambda : self ._close_session_after_time_secs ,
564- close_websocket = do_close_websocket ,
565+ close_websocket = close_websocket ,
565566 send_message = self .send_message ,
566567 increment_and_get_heartbeat_misses = increment_and_get_heartbeat_misses ,
567568 )
@@ -573,6 +574,10 @@ async def transition_connecting() -> None:
573574
574575 async def connection_interrupted () -> None :
575576 self ._state = SessionState .CONNECTING
577+ if self ._ws_unwrapped :
578+ self ._task_manager .create_task (self ._ws_unwrapped .close ())
579+ self ._ws_unwrapped = None
580+
576581 if self ._retry_connection_callback :
577582 self ._task_manager .create_task (self ._retry_connection_callback ())
578583
0 commit comments