@@ -215,8 +215,32 @@ def __init__(
215215 # Terminating
216216 self ._terminating_task = None
217217
218- self ._start_serve_responses ()
219- self ._start_close_session_checker ()
218+ async def transition_connecting () -> None :
219+ if self ._state in TerminalStates :
220+ return
221+ self ._state = SessionState .CONNECTING
222+ self ._wait_for_connected .clear ()
223+ await self .ensure_connected ()
224+
225+ async def transition_no_connection () -> None :
226+ if self ._state in TerminalStates :
227+ return
228+ self ._state = SessionState .NO_CONNECTION
229+ if self ._ws :
230+ self ._task_manager .create_task (self ._ws .close ())
231+ self ._ws = None
232+
233+ if self ._retry_connection_callback :
234+ self ._task_manager .create_task (self ._retry_connection_callback ())
235+
236+ await self ._begin_close_session_countdown ()
237+
238+ self ._start_serve_responses (
239+ transition_no_connection = transition_no_connection ,
240+ )
241+ self ._start_close_session_checker (
242+ transition_no_connection = transition_no_connection ,
243+ )
220244 self ._start_buffered_message_sender ()
221245
222246 async def ensure_connected (self ) -> None :
@@ -455,7 +479,9 @@ async def close(self) -> None:
455479 # This will get us GC'd, so this should be the last thing.
456480 await self ._close_session_callback (self )
457481
458- def _start_buffered_message_sender (self ) -> None :
482+ def _start_buffered_message_sender (
483+ self ,
484+ ) -> None :
459485 def commit (msg : TransportMessage ) -> None :
460486 pending = self ._send_buffer .popleft ()
461487 if msg .seq != pending .seq :
@@ -504,42 +530,29 @@ async def block_until_message_available() -> None:
504530 )
505531 )
506532
507- def _start_close_session_checker (self ) -> None :
508- def transition_connecting () -> None :
509- if self ._state in TerminalStates :
510- return
511- self ._state = SessionState .CONNECTING
512- self ._wait_for_connected .clear ()
513-
533+ def _start_close_session_checker (
534+ self ,
535+ transition_no_connection : Callable [[], Awaitable [None ]],
536+ ) -> None :
514537 self ._task_manager .create_task (
515538 _check_to_close_session (
516- self ._transport_options .close_session_check_interval_ms ,
517- lambda : self ._state ,
518- lambda : self ._ws ,
519- transition_connecting = transition_connecting ,
539+ close_session_check_interval_ms = self ._transport_options .close_session_check_interval_ms ,
540+ get_state = lambda : self ._state ,
541+ get_ws = lambda : self ._ws ,
542+ transition_no_connection = transition_no_connection ,
520543 )
521544 )
522545
523- def _start_serve_responses (self ) -> None :
546+ def _start_serve_responses (
547+ self ,
548+ transition_no_connection : Callable [[], Awaitable [None ]],
549+ ) -> None :
524550 def transition_connecting () -> None :
525551 if self ._state in TerminalStates :
526552 return
527553 self ._state = SessionState .CONNECTING
528554 self ._wait_for_connected .clear ()
529555
530- async def transition_no_connection () -> None :
531- if self ._state in TerminalStates :
532- return
533- self ._state = SessionState .NO_CONNECTION
534- if self ._ws :
535- self ._task_manager .create_task (self ._ws .close ())
536- self ._ws = None
537-
538- if self ._retry_connection_callback :
539- self ._task_manager .create_task (self ._retry_connection_callback ())
540-
541- await self ._begin_close_session_countdown ()
542-
543556 def assert_incoming_seq_bookkeeping (
544557 msg_from : str ,
545558 msg_seq : int ,
@@ -939,15 +952,15 @@ async def _check_to_close_session(
939952 close_session_check_interval_ms : float ,
940953 get_state : Callable [[], SessionState ],
941954 get_ws : Callable [[], ClientConnection | None ],
942- transition_connecting : Callable [[], None ],
955+ transition_no_connection : Callable [[], Awaitable [ None ] ],
943956) -> None :
944957 while get_state () not in TerminalStates :
945958 logger .debug ("_check_to_close_session: Checking" )
946959 await asyncio .sleep (close_session_check_interval_ms / 1000 )
947960
948961 if (ws := get_ws ()) and ws .protocol .state is CLOSED :
949962 logger .info ("Websocket is closed, transitioning to connecting" )
950- transition_connecting ()
963+ transition_no_connection ()
951964
952965
953966async def _do_ensure_connected [HandshakeMetadata ](
0 commit comments