Skip to content

Commit 028a49a

Browse files
Clarify _with_stream semantics
1 parent 4deb45f commit 028a49a

File tree

1 file changed

+49
-48
lines changed

1 file changed

+49
-48
lines changed

src/replit_river/v2/session.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
RiverException,
4747
RiverServiceException,
4848
SessionClosedRiverServiceException,
49-
StreamClosedRiverServiceException,
5049
exception_from_message,
5150
)
5251
from replit_river.messages import (
@@ -350,9 +349,6 @@ async def _send_message(
350349
span: Span | None = None,
351350
) -> None:
352351
"""Send serialized messages to the websockets."""
353-
# if the session is not active, we should not do anything
354-
if self._state in TerminalStates:
355-
return
356352
logger.debug(
357353
"_send_message(stream_id=%r, payload=%r, control_flags=%r, "
358354
"service_name=%r, procedure_name=%r)",
@@ -582,6 +578,16 @@ async def _with_stream(
582578
session_id: str,
583579
maxsize: int,
584580
) -> AsyncIterator[tuple[asyncio.Event, Channel[ResultType]]]:
581+
"""
582+
_with_stream
583+
584+
An async context that exposes a managed stream and an event that permits
585+
producers to respond to backpressure.
586+
587+
It is expected that the first message emitted ignores this backpressure_waiter,
588+
since the first event does not care about backpressure, but subsequent events
589+
emitted should call await backpressure_waiter.wait() prior to emission.
590+
"""
585591
output: Channel[Any] = Channel(maxsize=maxsize)
586592
backpressure_waiter = asyncio.Event()
587593
self._streams[session_id] = (backpressure_waiter, output)
@@ -606,15 +612,16 @@ async def send_rpc[R, A](
606612
Expects the input and output be messages that will be msgpacked.
607613
"""
608614
stream_id = nanoid.generate()
615+
await self._send_message(
616+
stream_id=stream_id,
617+
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
618+
payload=request_serializer(request),
619+
service_name=service_name,
620+
procedure_name=procedure_name,
621+
span=span,
622+
)
623+
609624
async with self._with_stream(stream_id, 1) as (backpressure_waiter, output):
610-
await self._send_message(
611-
stream_id=stream_id,
612-
control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT,
613-
payload=request_serializer(request),
614-
service_name=service_name,
615-
procedure_name=procedure_name,
616-
span=span,
617-
)
618625
# Handle potential errors during communication
619626
try:
620627
async with asyncio.timeout(timeout.total_seconds()):
@@ -665,19 +672,18 @@ async def send_upload[I, R, A](
665672
666673
Expects the input and output be messages that will be msgpacked.
667674
"""
668-
669675
stream_id = nanoid.generate()
676+
await self._send_message(
677+
stream_id=stream_id,
678+
control_flags=STREAM_OPEN_BIT,
679+
service_name=service_name,
680+
procedure_name=procedure_name,
681+
payload=init_serializer(init),
682+
span=span,
683+
)
684+
670685
async with self._with_stream(stream_id, 1) as (backpressure_waiter, output):
671686
try:
672-
await self._send_message(
673-
stream_id=stream_id,
674-
control_flags=STREAM_OPEN_BIT,
675-
service_name=service_name,
676-
procedure_name=procedure_name,
677-
payload=init_serializer(init),
678-
span=span,
679-
)
680-
681687
if request:
682688
assert request_serializer, "send_stream missing request_serializer"
683689

@@ -756,16 +762,16 @@ async def send_subscription[R, E, A](
756762
Expects the input and output be messages that will be msgpacked.
757763
"""
758764
stream_id = nanoid.generate()
759-
async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output):
760-
await self._send_message(
761-
service_name=service_name,
762-
procedure_name=procedure_name,
763-
stream_id=stream_id,
764-
control_flags=STREAM_OPEN_BIT,
765-
payload=request_serializer(request),
766-
span=span,
767-
)
765+
await self._send_message(
766+
service_name=service_name,
767+
procedure_name=procedure_name,
768+
stream_id=stream_id,
769+
control_flags=STREAM_OPEN_BIT,
770+
payload=request_serializer(request),
771+
span=span,
772+
)
768773

774+
async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output):
769775
# Handle potential errors during communication
770776
try:
771777
async for item in output:
@@ -811,24 +817,19 @@ async def send_stream[I, R, E, A](
811817
"""
812818

813819
stream_id = nanoid.generate()
814-
async with self._with_stream(
815-
stream_id,
816-
MAX_MESSAGE_BUFFER_SIZE,
817-
) as (backpressure_waiter, output):
818-
try:
819-
await self._send_message(
820-
service_name=service_name,
821-
procedure_name=procedure_name,
822-
stream_id=stream_id,
823-
control_flags=STREAM_OPEN_BIT,
824-
payload=init_serializer(init),
825-
span=span,
826-
)
827-
except Exception as e:
828-
raise StreamClosedRiverServiceException(
829-
ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name
830-
) from e
820+
await self._send_message(
821+
service_name=service_name,
822+
procedure_name=procedure_name,
823+
stream_id=stream_id,
824+
control_flags=STREAM_OPEN_BIT,
825+
payload=init_serializer(init),
826+
span=span,
827+
)
831828

829+
async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
830+
backpressure_waiter,
831+
output,
832+
):
832833
# Create the encoder task
833834
async def _encode_stream() -> None:
834835
if not request:

0 commit comments

Comments
 (0)