Skip to content

Commit 56e0c3c

Browse files
Clarifying termination semantics
1 parent 80ba340 commit 56e0c3c

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

src/replit_river/v2/session.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ def _start_buffered_message_sender(
494494
Building on buffered_message_sender's documentation, we implement backpressure
495495
per-stream by way of self._streams'
496496
497-
error_channel: Channel[Exception | None]
497+
error_channel: Channel[Exception]
498+
backpressured_waiter: Callable[[], Awaitable[None]]
498499
499500
This is accomplished via the following strategy:
500501
- If buffered_message_sender encounters an error, we transition back to
@@ -506,8 +507,11 @@ def _start_buffered_message_sender(
506507
- Alternately, if buffered_message_sender successfully writes back to the
507508
508509
- Finally, if _recv_from_ws encounters an error (transport or deserialization),
509-
we emit an informative error to close_session which gets emitted to all
510-
backpressured client methods.
510+
it transitions to NO_CONNECTION and defers to the client_transport to
511+
reestablish a connection.
512+
513+
The in-flight messages are still valid, as if we can reconnect to the server
514+
in time, those responses can be marshalled to their respective callbacks.
511515
"""
512516

513517
async def commit(msg: TransportMessage) -> None:
@@ -729,7 +733,8 @@ async def send_rpc[R, A](
729733
# Handle potential errors during communication
730734
try:
731735
async with asyncio.timeout(timeout.total_seconds()):
732-
# Block for backpressure
736+
# Block for backpressure. For an RPC this is trivially true
737+
# but here for consistency with the other methods.
733738
await backpressured_waiter()
734739
# Race output and error channels
735740
raced = await _race2(anext(output), error_channel.get())
@@ -801,9 +806,11 @@ async def send_upload[I, R, A](
801806
# If this request is not closed and the session is killed, we should
802807
# throw exception here
803808
async for item in request:
804-
# Block for backpressure and emission errors from the ws
809+
# Block for backpressure
805810
await backpressured_waiter()
806811
try:
812+
# We check every tick to see whether we've seen an error
813+
# since we're responsible for emitting as quickly as possible.
807814
raise error_channel.get_nowait()
808815
except (ChannelClosed, ChannelEmpty):
809816
# No errors, off to the next message
@@ -900,12 +907,20 @@ async def send_subscription[I, E, A](
900907
)
901908

902909
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
903-
backpressured_waiter,
904910
_,
911+
error_channel,
905912
output,
906913
):
907914
try:
908915
async for item in output:
916+
try:
917+
# We check every tick to see whether we've seen an error
918+
# since we're responsible for emitting as quickly as possible.
919+
raise error_channel.get_nowait()
920+
except (ChannelClosed, ChannelEmpty):
921+
# No errors, off to the next message
922+
pass
923+
909924
if item.get("type") == "CLOSE":
910925
break
911926
if not item.get("ok", False):
@@ -966,9 +981,16 @@ async def _encode_stream(
966981
request: AsyncIterable[R], request_serializer: Callable[[R], Any]
967982
) -> None:
968983
async for item in request:
969-
# Block for backpressure (or errors)
984+
# Block for backpressure
970985
await backpressured_waiter()
971-
# If there are any errors so far, raise them
986+
try:
987+
# We check every tick to see whether we've seen an error
988+
# since we're responsible for emitting as quickly as possible.
989+
raise error_channel.get_nowait()
990+
except (ChannelClosed, ChannelEmpty):
991+
# No errors, off to the next message
992+
pass
993+
972994
await self._enqueue_message(
973995
stream_id=stream_id,
974996
control_flags=0,

0 commit comments

Comments
 (0)