@@ -248,12 +248,22 @@ def get_next_sent_seq() -> int:
248248 return self .seq
249249
def close_session(reason: Exception | None) -> None:
    """Schedule a background shutdown of the session, at most once.

    Does nothing when the session is already in a terminal state or a
    terminating task has already been scheduled.

    Args:
        reason: The error that triggered the shutdown, or None; it is
            forwarded unchanged to ``self.close``.
    """
    # Either the session is already closing/closed, or a terminating task
    # is already in flight: let whoever started that finish the job.
    if self._state in TerminalStates or self._terminating_task is not None:
        return

    # Capture the pre-shutdown state before flipping to CLOSING. It is
    # handed to close() so that close() evaluates its "already closing"
    # guard against that state rather than the CLOSING value set here.
    prior_state = self._state
    self._state = SessionState.CLOSING

    # close() is scheduled as a task rather than run inline: this callback
    # is synchronous, and (per the original author's note) the caller is
    # itself awaited during the cleanup procedure.
    shutdown = self.close(reason, current_state=prior_state)
    self._terminating_task = asyncio.create_task(shutdown)
257267
258268 def transition_connecting () -> None :
259269 if self ._state in TerminalStates :
@@ -301,6 +311,7 @@ def unbind_connecting_task() -> None:
301311 get_next_sent_seq = get_next_sent_seq ,
302312 get_current_ack = lambda : self .ack ,
303313 get_current_time = self ._get_current_time ,
314+ get_state = lambda : self ._state ,
304315 transition_connecting = transition_connecting ,
305316 close_ws_in_background = close_ws_in_background ,
306317 transition_connected = transition_connected ,
@@ -385,12 +396,12 @@ async def _enqueue_message(
385396 # Wake up buffered_message_sender
386397 self ._process_messages .set ()
387398
388- async def close (self , reason : Exception | None = None ) -> None :
399+ async def close (self , reason : Exception | None = None , current_state : SessionState | None = None ) -> None :
389400 """Close the session and all associated streams."""
390401 logger .info (
391402 f"{ self .session_id } closing session to { self ._server_id } , ws: { self ._ws } "
392403 )
393- if self ._state in TerminalStates :
404+ if ( current_state or self ._state ) in TerminalStates :
394405 # already closing
395406 return
396407 self ._state = SessionState .CLOSING
@@ -987,6 +998,7 @@ async def _do_ensure_connected[HandshakeMetadata](
987998 get_current_time : Callable [[], Awaitable [float ]],
988999 get_next_sent_seq : Callable [[], int ],
9891000 get_current_ack : Callable [[], int ],
1001+ get_state : Callable [[], SessionState ],
9901002 transition_connecting : Callable [[], None ],
9911003 close_ws_in_background : Callable [[ClientConnection ], None ],
9921004 transition_connected : Callable [[ClientConnection ], None ],
@@ -998,6 +1010,10 @@ async def _do_ensure_connected[HandshakeMetadata](
9981010 last_error : Exception | None = None
9991011 attempt_count = 0
10001012 while rate_limiter .has_budget (client_id ):
1013+ if (state := get_state ()) in TerminalStates or state in ActiveStates :
1014+ logger .info (f"_do_ensure_connected stopping due to state={ state } " )
1015+ break
1016+
10011017 if attempt_count > 0 :
10021018 logger .info (f"Retrying build handshake number { attempt_count } times" )
10031019 attempt_count += 1
@@ -1051,40 +1067,40 @@ async def websocket_closed_callback() -> None:
10511067 handshake_deadline_ms = (
10521068 await get_current_time () + transport_options .handshake_timeout_ms
10531069 )
1054- while True :
1055- if await get_current_time () >= handshake_deadline_ms :
1056- raise RiverException (
1057- ERROR_HANDSHAKE ,
1058- "Handshake response timeout, closing connection" ,
1059- )
1060- try :
1061- data = await ws .recv (decode = False )
1062- except ConnectionClosed as e :
1063- logger .debug (
1064- "_do_ensure_connected: Connection closed during waiting "
1065- "for handshake response" ,
1066- exc_info = True ,
1067- )
1068- raise RiverException (
1069- ERROR_HANDSHAKE ,
1070- "Handshake failed, conn closed while waiting for response" ,
1071- ) from e
10721070
1073- try :
1074- response_msg = parse_transport_msg (data )
1075- if isinstance (response_msg , str ):
1076- logger .debug (
1077- "_do_ensure_connected: Ignoring transport message" ,
1078- exc_info = True ,
1079- )
1080- continue
1071+ if await get_current_time () >= handshake_deadline_ms :
1072+ raise RiverException (
1073+ ERROR_HANDSHAKE ,
1074+ "Handshake response timeout, closing connection" ,
1075+ )
10811076
1082- break
1083- except InvalidMessageException as e :
1084- raise RiverException (
1085- ERROR_HANDSHAKE ,
1086- "Got invalid transport message, closing connection" ,
1087- ) from e
1077+ try :
1078+ data = await ws .recv (decode = False )
1079+ except ConnectionClosed as e :
1080+ logger .debug (
1081+ "_do_ensure_connected: Connection closed during waiting "
1082+ "for handshake response" ,
1083+ exc_info = True ,
1084+ )
1085+ raise RiverException (
1086+ ERROR_HANDSHAKE ,
1087+ "Handshake failed, conn closed while waiting for response" ,
1088+ ) from e
1089+
1090+ try :
1091+ response_msg = parse_transport_msg (data )
1092+ except InvalidMessageException as e :
1093+ raise RiverException (
1094+ ERROR_HANDSHAKE ,
1095+ "Got invalid transport message, closing connection" ,
1096+ ) from e
1097+
1098+ if isinstance (response_msg , str ):
1099+ raise RiverException (
1100+ ERROR_HANDSHAKE ,
1101+ "Handshake failed, received a raw string message while waiting "
1102+ "for a handshake response" ,
1103+ )
10881104
10891105 try :
10901106 handshake_response = ControlMessageHandshakeResponse (
@@ -1105,6 +1121,7 @@ async def websocket_closed_callback() -> None:
11051121 } " ,
11061122 )
11071123 if handshake_response .status .code == ERROR_CODE_SESSION_STATE_MISMATCH :
1124+ # A session state mismatch is unrecoverable. Terminate immediately.
11081125 close_session (err )
11091126
11101127 raise err
0 commit comments