Skip to content

Commit 70684e4

Browse files
Communicate handshake errors back to callers as well
1 parent 01b079d commit 70684e4

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/replit_river/common_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ async def buffered_message_sender(
6262
commit: Callable[[TransportMessage], Awaitable[None]],
6363
get_state: Callable[[], SessionState],
6464
) -> None:
65+
"""
66+
buffered_message_sender runs in a task and consumes from a queue, emitting
67+
messages over the websocket as quickly as it can.
68+
69+
One of the design goals is to keep the message queue as short as possible to permit
70+
quickly cancelling streams or acking heartbeats, so to that end it is wise to
71+
incorporate backpressure into the lifecycle of get_next_pending/commit.
72+
"""
73+
6574
our_task = asyncio.current_task()
6675
while our_task and not our_task.cancelling() and not our_task.cancelled():
6776
while get_state() in ConnectingStates:

src/replit_river/v2/session.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ def get_next_sent_seq() -> int:
239239
return self._send_buffer[0].seq
240240
return self.seq
241241

242-
def do_close() -> None:
242+
def close_session(reason: Exception | None) -> None:
243243
# Avoid closing twice
244244
if self._terminating_task is None:
245245
# We can't just call self.close() directly because
246246
# we're inside a task that will eventually be awaited
247247
# during the cleanup procedure.
248-
self._terminating_task = asyncio.create_task(self.close())
248+
self._terminating_task = asyncio.create_task(self.close(reason))
249249

250250
def transition_connecting() -> None:
251251
if self._state in TerminalStates:
@@ -296,7 +296,7 @@ def unbind_connecting_task() -> None:
296296
close_ws_in_background=close_ws_in_background,
297297
transition_connected=transition_connected,
298298
unbind_connecting_task=unbind_connecting_task,
299-
do_close=do_close,
299+
close_session=close_session,
300300
)
301301
)
302302

@@ -433,6 +433,26 @@ async def close(self, reason: Exception | None = None) -> None:
433433
def _start_buffered_message_sender(
434434
self,
435435
) -> None:
436+
"""
437+
Building on buffered_message_sender's documentation, we implement backpressure
438+
per-stream by way of self._streams'
439+
440+
error_channel: Channel[Exception | None]
441+
442+
This is accomplished via the following strategy:
443+
- If buffered_message_sender encounters an error, we transition back to
444+
connecting and attempt to handshake.
445+
446+
If the handshake fails, we close the session with an informative error that
447+
gets emitted to all backpressured client methods.
448+
449+
- Alternately, if buffered_message_sender successfully writes back to the websocket, ... (TODO: this sentence was left unfinished; describe the success path.)
450+
451+
- Finally, if _recv_from_ws encounters an error (transport or deserialization),
452+
we pass an informative error to close_session, which is then emitted to all
453+
backpressured client methods.
454+
"""
455+
436456
async def commit(msg: TransportMessage) -> None:
437457
pending = self._send_buffer.popleft()
438458
if msg.seq != pending.seq:
@@ -935,7 +955,7 @@ async def _do_ensure_connected[HandshakeMetadata](
935955
close_ws_in_background: Callable[[ClientConnection], None],
936956
transition_connected: Callable[[ClientConnection], None],
937957
unbind_connecting_task: Callable[[], None],
938-
do_close: Callable[[], None],
958+
close_session: Callable[[Exception | None], None],
939959
) -> None:
940960
logger.info("Attempting to establish new ws connection")
941961

@@ -1040,15 +1060,16 @@ async def websocket_closed_callback() -> None:
10401060

10411061
logger.debug("river client get handshake response : %r", handshake_response)
10421062
if not handshake_response.status.ok:
1043-
if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH:
1044-
do_close()
1045-
1046-
raise RiverException(
1063+
err = RiverException(
10471064
ERROR_HANDSHAKE,
10481065
f"Handshake failed with code {handshake_response.status.code}: {
10491066
handshake_response.status.reason
10501067
}",
10511068
)
1069+
if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH:
1070+
close_session(err)
1071+
1072+
raise err
10521073

10531074
# We did it! We're connected!
10541075
last_error = None
@@ -1069,7 +1090,7 @@ async def websocket_closed_callback() -> None:
10691090

10701091
if last_error is not None:
10711092
logger.debug("Handshake attempts exhausted, terminating")
1072-
do_close()
1093+
close_session(last_error)
10731094
raise RiverException(
10741095
ERROR_HANDSHAKE,
10751096
f"Failed to create ws after retrying {attempt_count} number of times",

0 commit comments

Comments
 (0)