@@ -239,13 +239,13 @@ def get_next_sent_seq() -> int:
239239 return self ._send_buffer [0 ].seq
240240 return self .seq
241241
242- def do_close ( ) -> None :
242+ def close_session ( reason : Exception | None ) -> None :
243243 # Avoid closing twice
244244 if self ._terminating_task is None :
245245 # We can't just call self.close() directly because
246246 # we're inside a thread that will eventually be awaited
247247 # during the cleanup procedure.
248- self ._terminating_task = asyncio .create_task (self .close ())
248+ self ._terminating_task = asyncio .create_task (self .close (reason ))
249249
250250 def transition_connecting () -> None :
251251 if self ._state in TerminalStates :
@@ -296,7 +296,7 @@ def unbind_connecting_task() -> None:
296296 close_ws_in_background = close_ws_in_background ,
297297 transition_connected = transition_connected ,
298298 unbind_connecting_task = unbind_connecting_task ,
299- do_close = do_close ,
299+ close_session = close_session ,
300300 )
301301 )
302302
@@ -433,6 +433,26 @@ async def close(self, reason: Exception | None = None) -> None:
433433 def _start_buffered_message_sender (
434434 self ,
435435 ) -> None :
436+ """
437+ Building on buffered_message_sender's documentation, we implement backpressure
438+ per-stream by way of self._streams'
439+
440+ error_channel: Channel[Exception | None]
441+
442+ This is accomplished via the following strategy:
443+ - If buffered_message_sender encounters an error, we transition back to
444+ connecting and attempt to handshake.
445+
446+ If the handshake fails, we close the session with an informative error that
447+ gets emitted to all backpressured client methods.
448+
449+ - Alternately, if buffered_message_sender successfully writes back to the
450+
451+ - Finally, if _recv_from_ws encounters an error (transport or deserialization),
452+ we emit an informative error to close_session which gets emitted to all
453+ backpressured client methods.
454+ """
455+
436456 async def commit (msg : TransportMessage ) -> None :
437457 pending = self ._send_buffer .popleft ()
438458 if msg .seq != pending .seq :
@@ -935,7 +955,7 @@ async def _do_ensure_connected[HandshakeMetadata](
935955 close_ws_in_background : Callable [[ClientConnection ], None ],
936956 transition_connected : Callable [[ClientConnection ], None ],
937957 unbind_connecting_task : Callable [[], None ],
938- do_close : Callable [[], None ],
958+ close_session : Callable [[Exception | None ], None ],
939959) -> None :
940960 logger .info ("Attempting to establish new ws connection" )
941961
@@ -1040,15 +1060,16 @@ async def websocket_closed_callback() -> None:
10401060
10411061 logger .debug ("river client get handshake response : %r" , handshake_response )
10421062 if not handshake_response .status .ok :
1043- if handshake_response .status .code == ERROR_CODE_SESSION_STATE_MISMATCH :
1044- do_close ()
1045-
1046- raise RiverException (
1063+ err = RiverException (
10471064 ERROR_HANDSHAKE ,
10481065 f"Handshake failed with code { handshake_response .status .code } : {
10491066 handshake_response .status .reason
10501067 } " ,
10511068 )
1069+ if handshake_response .status .code == ERROR_CODE_SESSION_STATE_MISMATCH :
1070+ close_session (err )
1071+
1072+ raise err
10521073
10531074 # We did it! We're connected!
10541075 last_error = None
@@ -1069,7 +1090,7 @@ async def websocket_closed_callback() -> None:
10691090
10701091 if last_error is not None :
10711092 logger .debug ("Handshake attempts exhausted, terminating" )
1072- do_close ( )
1093+ close_session ( last_error )
10731094 raise RiverException (
10741095 ERROR_HANDSHAKE ,
10751096 f"Failed to create ws after retrying { attempt_count } number of times" ,
0 commit comments