Skip to content

Commit 27e8e46

Browse files
Moving Exception over to a dedicated channel
1 parent 7bcc717 commit 27e8e46

File tree

1 file changed

+107
-31
lines changed

1 file changed

+107
-31
lines changed

src/replit_river/v2/session.py

Lines changed: 107 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,9 @@ async def _with_stream(
622622
span: Span,
623623
stream_id: str,
624624
maxsize: int,
625-
) -> AsyncIterator[tuple[_BackpressuredWaiter, AsyncIterator[ResultType]]]:
625+
) -> AsyncIterator[
626+
tuple[_BackpressuredWaiter, Channel[Exception], AsyncIterator[ResultType]]
627+
]:
626628
"""
627629
_with_stream
628630
@@ -663,7 +665,7 @@ async def error_checking_output() -> AsyncIterator[ResultType]:
663665
yield elem
664666

665667
try:
666-
yield (backpressured_waiter, error_checking_output())
668+
yield (backpressured_waiter, error_channel, error_checking_output())
667669
finally:
668670
stream_meta = self._streams.get(stream_id)
669671
if not stream_meta:
@@ -706,14 +708,16 @@ async def send_rpc[R, A](
706708

707709
async with self._with_stream(span, stream_id, 1) as (
708710
backpressured_waiter,
711+
error_channel,
709712
output,
710713
):
711714
# Handle potential errors during communication
712715
try:
713716
async with asyncio.timeout(timeout.total_seconds()):
714-
# Block for backpressure and emission errors from the ws
717+
# Block for backpressure
715718
await backpressured_waiter()
716-
result = await anext(output)
719+
# Race output and error channels
720+
raced = await _race2(anext(output), error_channel.get())
717721
except asyncio.TimeoutError as e:
718722
await self._send_cancel_stream(
719723
stream_id=stream_id,
@@ -728,17 +732,24 @@ async def send_rpc[R, A](
728732
service_name,
729733
procedure_name,
730734
) from e
735+
except Exception as e:
736+
raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e
737+
match raced:
738+
case _FinishedA(result):
739+
if "ok" not in result or not result["ok"]:
740+
try:
741+
error = error_deserializer(result["payload"])
742+
except Exception as e:
743+
raise RiverException("error_deserializer", str(e)) from e
744+
raise exception_from_message(error.code)(
745+
error.code, error.message, service_name, procedure_name
746+
)
731747

732-
if "ok" not in result or not result["ok"]:
733-
try:
734-
error = error_deserializer(result["payload"])
735-
except Exception as e:
736-
raise RiverException("error_deserializer", str(e)) from e
737-
raise exception_from_message(error.code)(
738-
error.code, error.message, service_name, procedure_name
739-
)
740-
741-
return response_deserializer(result["payload"])
748+
return response_deserializer(result["payload"])
749+
case _FinishedB(err):
750+
raise err
751+
case other:
752+
assert_never(other)
742753

743754
async def send_upload[I, R, A](
744755
self,
@@ -768,6 +779,7 @@ async def send_upload[I, R, A](
768779

769780
async with self._with_stream(span, stream_id, 1) as (
770781
backpressured_waiter,
782+
error_channel,
771783
output,
772784
):
773785
try:
@@ -776,6 +788,12 @@ async def send_upload[I, R, A](
776788
async for item in request:
777789
# Block for backpressure and emission errors from the ws
778790
await backpressured_waiter()
791+
try:
792+
raise error_channel.get_nowait()
793+
except (ChannelClosed, ChannelEmpty):
794+
# No errors, off to the next message
795+
pass
796+
779797
try:
780798
payload = request_serializer(item)
781799
except Exception as e:
@@ -867,6 +885,7 @@ async def send_subscription[I, E, A](
867885
)
868886

869887
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
888+
backpressured_waiter,
870889
_,
871890
output,
872891
):
@@ -904,9 +923,12 @@ async def send_stream[I, R, E, A](
904923
error_deserializer: Callable[[Any], E],
905924
span: Span,
906925
) -> AsyncGenerator[A | E, None]:
907-
"""Sends a subscription request to the server.
926+
"""Sends a bidirectional stream request to the server.
908927
909-
Expects the input and output be messages that will be msgpacked.
928+
When the request AsyncIterable finishes, a "CLOSE" event is sent to the server.
929+
930+
If a "CLOSE" event is received from the server, the output channel will be
931+
closed, but the requests will still be sent.
910932
"""
911933

912934
stream_id = nanoid.generate()
@@ -921,19 +943,13 @@ async def send_stream[I, R, E, A](
921943

922944
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
923945
backpressured_waiter,
946+
error_channel,
924947
output,
925948
):
926949
# Create the encoder task
927-
async def _encode_stream() -> None:
928-
if not request:
929-
await self._send_close_stream(
930-
stream_id=stream_id,
931-
span=span,
932-
)
933-
return
934-
935-
assert request_serializer, "send_stream missing request_serializer"
936-
950+
async def _encode_stream(
951+
request: AsyncIterable[R], request_serializer: Callable[[R], Any]
952+
) -> None:
937953
async for item in request:
938954
# Block for backpressure (or errors)
939955
await backpressured_waiter()
@@ -948,23 +964,54 @@ async def _encode_stream() -> None:
948964
span=span,
949965
)
950966

951-
emitter_task = self._task_manager.create_task(_encode_stream())
967+
emitter_task = None
968+
if not request:
969+
await self._send_close_stream(
970+
stream_id=stream_id,
971+
span=span,
972+
)
973+
else:
974+
assert request_serializer, "send_stream missing request_serializer"
975+
976+
# Now that we've validated these values, shadow them in the Task
977+
emitter_task = self._task_manager.create_task(
978+
_encode_stream(request, request_serializer)
979+
)
952980

953981
# Handle potential errors during communication
954982
try:
955983
async for result in output:
956984
# Raise as early as we possibly can in case of an emission error
957-
if emitter_task.done() and (err := emitter_task.exception()):
985+
if (
986+
emitter_task
987+
and emitter_task.done()
988+
and (err := emitter_task.exception())
989+
):
990+
raise err
991+
992+
# If the emitter channel has finished, we still need to check in
993+
# the consumer channel.
994+
try:
995+
err = error_channel.get_nowait()
996+
await self._send_close_stream(
997+
stream_id=stream_id,
998+
span=span,
999+
)
9581000
raise err
1001+
except (ChannelClosed, ChannelEmpty):
1002+
# No errors, off to the next message
1003+
pass
1004+
9591005
if result.get("type") == "CLOSE":
9601006
break
1007+
9611008
if "ok" not in result or not result["ok"]:
9621009
yield error_deserializer(result["payload"])
9631010
continue
9641011
yield response_deserializer(result["payload"])
965-
# ... block the outer function until the emitter is finished emitting,
966-
# possibly raising a terminal exception.
967-
await emitter_task
1012+
# ... block the outer function until the emitter is finished emitting.
1013+
if emitter_task:
1014+
await emitter_task
9681015
except Exception as e:
9691016
await self._send_cancel_stream(
9701017
stream_id=stream_id,
@@ -1356,3 +1403,32 @@ async def _recv_from_ws(
13561403
)
13571404
raise unhandled
13581405
logger.debug(f"_recv_from_ws exiting normally after {connection_attempts} loops")
1406+
1407+
1408+
@dataclass(frozen=True)
1409+
class _FinishedA[A]:
1410+
value: A
1411+
1412+
1413+
@dataclass(frozen=True)
1414+
class _FinishedB[B]:
1415+
value: B
1416+
1417+
1418+
async def _race2[A, B](
1419+
a: Awaitable[A],
1420+
b: Awaitable[B],
1421+
) -> _FinishedA[A] | _FinishedB[B]:
1422+
_a = asyncio.ensure_future(a)
1423+
_b = asyncio.ensure_future(b)
1424+
await asyncio.wait([_a, _b], return_when=asyncio.FIRST_COMPLETED)
1425+
1426+
if _a.done():
1427+
_b.cancel()
1428+
if err := _a.exception():
1429+
raise err
1430+
return _FinishedA(_a.result())
1431+
_a.cancel()
1432+
if err := _b.exception():
1433+
raise err
1434+
return _FinishedB(_b.result())

0 commit comments

Comments
 (0)