Skip to content

Commit 454669d

Browse files
Decouple error channel from backpressure channel
Discovered that it was overloaded, written to by multiple different sources with different semantics.
1 parent 1e206f5 commit 454669d

File tree

1 file changed

+52
-25
lines changed

1 file changed

+52
-25
lines changed

src/replit_river/v2/session.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import nanoid
2323
import websockets.asyncio.client
24-
from aiochannel import Channel, ChannelFull
24+
from aiochannel import Channel, ChannelEmpty, ChannelFull
2525
from aiochannel.errors import ChannelClosed
2626
from opentelemetry.trace import Span, use_span
2727
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
@@ -83,6 +83,9 @@
8383
STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000
8484

8585

86+
_BackpressuredWaiter: TypeAlias = Callable[[], Awaitable[None]]
87+
88+
8689
class ResultOk(TypedDict):
8790
ok: Literal[True]
8891
payload: Any
@@ -120,7 +123,8 @@ class _IgnoreMessage:
120123

121124
class StreamMeta(TypedDict):
122125
span: Span
123-
error_channel: Channel[None | Exception]
126+
release_backpressured_waiter: Callable[[], None]
127+
error_channel: Channel[Exception]
124128
output: Channel[Any]
125129

126130

@@ -417,6 +421,7 @@ async def close(self, reason: Exception | None = None) -> None:
417421
logger.exception(
418422
"Unable to tell the caller that the session is going away",
419423
)
424+
stream_meta["release_backpressured_waiter"]()
420425
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
421426
await asyncio.gather(
422427
*[stream_meta["output"].join() for stream_meta in self._streams.values()]
@@ -473,7 +478,7 @@ async def commit(msg: TransportMessage) -> None:
473478
# Wake up backpressured writer
474479
stream_meta = self._streams.get(pending.streamId)
475480
if stream_meta:
476-
await stream_meta["error_channel"].put(None)
481+
stream_meta["release_backpressured_waiter"]()
477482

478483
def get_next_pending() -> TransportMessage | None:
479484
if self._send_buffer:
@@ -584,7 +589,7 @@ async def _with_stream(
584589
span: Span,
585590
stream_id: str,
586591
maxsize: int,
587-
) -> AsyncIterator[tuple[Channel[None | Exception], Channel[ResultType]]]:
592+
) -> AsyncIterator[tuple[_BackpressuredWaiter, AsyncIterator[ResultType]]]:
588593
"""
589594
_with_stream
590595
@@ -596,14 +601,36 @@ async def _with_stream(
596601
emitted should call await error_channel.wait() prior to emission.
597602
"""
598603
output: Channel[Any] = Channel(maxsize=maxsize)
599-
error_channel: Channel[None | Exception] = Channel(maxsize=1)
604+
backpressured_waiter_event: asyncio.Event = asyncio.Event()
605+
error_channel: Channel[Exception] = Channel(maxsize=1)
600606
self._streams[stream_id] = {
601607
"span": span,
602608
"error_channel": error_channel,
609+
"release_backpressured_waiter": backpressured_waiter_event.set,
603610
"output": output,
604611
}
612+
613+
async def backpressured_waiter() -> None:
614+
await backpressured_waiter_event.wait()
615+
try:
616+
err = error_channel.get_nowait()
617+
raise err
618+
except (ChannelClosed, ChannelEmpty):
619+
# No errors, off to the next message
620+
pass
621+
622+
async def error_checking_output() -> AsyncIterator[ResultType]:
623+
async for elem in output:
624+
try:
625+
err = error_channel.get_nowait()
626+
raise err
627+
except (ChannelClosed, ChannelEmpty):
628+
# No errors, off to the next message
629+
pass
630+
yield elem
631+
605632
try:
606-
yield (error_channel, output)
633+
yield (backpressured_waiter, error_checking_output())
607634
finally:
608635
stream_meta = self._streams.get(stream_id)
609636
if not stream_meta:
@@ -644,14 +671,16 @@ async def send_rpc[R, A](
644671
span=span,
645672
)
646673

647-
async with self._with_stream(span, stream_id, 1) as (error_channel, output):
674+
async with self._with_stream(span, stream_id, 1) as (
675+
backpressured_waiter,
676+
output,
677+
):
648678
# Handle potential errors during communication
649679
try:
650680
async with asyncio.timeout(timeout.total_seconds()):
651681
# Block for backpressure and emission errors from the ws
652-
if err := await error_channel.get():
653-
raise err
654-
result = await output.get()
682+
await backpressured_waiter()
683+
result = await anext(output)
655684
except asyncio.TimeoutError as e:
656685
await self._send_cancel_stream(
657686
stream_id=stream_id,
@@ -705,15 +734,16 @@ async def send_upload[I, R, A](
705734
span=span,
706735
)
707736

708-
async with self._with_stream(span, stream_id, 1) as (error_channel, output):
737+
async with self._with_stream(span, stream_id, 1) as (
738+
backpressured_waiter,
739+
output,
740+
):
709741
try:
710742
# If this request is not closed and the session is killed, we should
711743
# throw exception here
712744
async for item in request:
713745
# Block for backpressure and emission errors from the ws
714-
if err := await error_channel.get():
715-
raise err
716-
746+
await backpressured_waiter()
717747
try:
718748
payload = request_serializer(item)
719749
except Exception as e:
@@ -753,7 +783,7 @@ async def send_upload[I, R, A](
753783
)
754784

755785
try:
756-
result = await output.get()
786+
result = await anext(output)
757787
except ChannelClosed as e:
758788
raise RiverServiceException(
759789
ERROR_CODE_STREAM_CLOSED,
@@ -856,7 +886,7 @@ async def send_stream[I, R, E, A](
856886
)
857887

858888
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
859-
error_channel,
889+
backpressured_waiter,
860890
output,
861891
):
862892
# Create the encoder task
@@ -871,13 +901,9 @@ async def _encode_stream() -> None:
871901
assert request_serializer, "send_stream missing request_serializer"
872902

873903
async for item in request:
874-
# Block for backpressure and emission errors from the ws
875-
if err := await error_channel.get():
876-
await self._send_close_stream(
877-
stream_id=stream_id,
878-
span=span,
879-
)
880-
raise err
904+
# Block for backpressure (or errors)
905+
await backpressured_waiter()
906+
# If there are any errors so far, raise them
881907
await self._enqueue_message(
882908
stream_id=stream_id,
883909
control_flags=0,
@@ -894,14 +920,15 @@ async def _encode_stream() -> None:
894920
try:
895921
async for result in output:
896922
# Raise as early as we possibly can in case of an emission error
897-
if err := emitter_task.done() and emitter_task.exception():
923+
if emitter_task.done() and (err := emitter_task.exception()):
898924
raise err
899925
if result.get("type") == "CLOSE":
900926
break
901927
if "ok" not in result or not result["ok"]:
902928
yield error_deserializer(result["payload"])
903929
yield response_deserializer(result["payload"])
904-
# ... block the outer function until the emitter is finished emitting.
930+
# ... block the outer function until the emitter is finished emitting,
931+
# possibly raising a terminal exception.
905932
await emitter_task
906933
except Exception as e:
907934
await self._send_cancel_stream(

0 commit comments

Comments
 (0)