4646 RiverException ,
4747 RiverServiceException ,
4848 SessionClosedRiverServiceException ,
49- StreamClosedRiverServiceException ,
5049 exception_from_message ,
5150)
5251from replit_river .messages import (
@@ -350,9 +349,6 @@ async def _send_message(
350349 span : Span | None = None ,
351350 ) -> None :
352351 """Send serialized messages to the websockets."""
353- # if the session is not active, we should not do anything
354- if self ._state in TerminalStates :
355- return
356352 logger .debug (
357353 "_send_message(stream_id=%r, payload=%r, control_flags=%r, "
358354 "service_name=%r, procedure_name=%r)" ,
@@ -582,6 +578,16 @@ async def _with_stream(
582578 session_id : str ,
583579 maxsize : int ,
584580 ) -> AsyncIterator [tuple [asyncio .Event , Channel [ResultType ]]]:
581+ """
582+ _with_stream
583+
584+ An async context that exposes a managed stream and an event that permits
585+ producers to respond to backpressure.
586+
587+ It is expected that the first message emitted ignores this backpressure_waiter,
588+ since the first event does not care about backpressure, but subsequent events
589+ emitted should call await backpressure_waiter.wait() prior to emission.
590+ """
585591 output : Channel [Any ] = Channel (maxsize = maxsize )
586592 backpressure_waiter = asyncio .Event ()
587593 self ._streams [session_id ] = (backpressure_waiter , output )
@@ -606,15 +612,16 @@ async def send_rpc[R, A](
606612 Expects the input and output be messages that will be msgpacked.
607613 """
608614 stream_id = nanoid .generate ()
615+ await self ._send_message (
616+ stream_id = stream_id ,
617+ control_flags = STREAM_OPEN_BIT | STREAM_CLOSED_BIT ,
618+ payload = request_serializer (request ),
619+ service_name = service_name ,
620+ procedure_name = procedure_name ,
621+ span = span ,
622+ )
623+
609624 async with self ._with_stream (stream_id , 1 ) as (backpressure_waiter , output ):
610- await self ._send_message (
611- stream_id = stream_id ,
612- control_flags = STREAM_OPEN_BIT | STREAM_CLOSED_BIT ,
613- payload = request_serializer (request ),
614- service_name = service_name ,
615- procedure_name = procedure_name ,
616- span = span ,
617- )
618625 # Handle potential errors during communication
619626 try :
620627 async with asyncio .timeout (timeout .total_seconds ()):
@@ -665,19 +672,18 @@ async def send_upload[I, R, A](
665672
666673 Expects the input and output be messages that will be msgpacked.
667674 """
668-
669675 stream_id = nanoid .generate ()
676+ await self ._send_message (
677+ stream_id = stream_id ,
678+ control_flags = STREAM_OPEN_BIT ,
679+ service_name = service_name ,
680+ procedure_name = procedure_name ,
681+ payload = init_serializer (init ),
682+ span = span ,
683+ )
684+
670685 async with self ._with_stream (stream_id , 1 ) as (backpressure_waiter , output ):
671686 try :
672- await self ._send_message (
673- stream_id = stream_id ,
674- control_flags = STREAM_OPEN_BIT ,
675- service_name = service_name ,
676- procedure_name = procedure_name ,
677- payload = init_serializer (init ),
678- span = span ,
679- )
680-
681687 if request :
682688 assert request_serializer , "send_stream missing request_serializer"
683689
@@ -756,16 +762,16 @@ async def send_subscription[R, E, A](
756762 Expects the input and output be messages that will be msgpacked.
757763 """
758764 stream_id = nanoid .generate ()
759- async with self ._with_stream (stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (_ , output ):
760- await self ._send_message (
761- service_name = service_name ,
762- procedure_name = procedure_name ,
763- stream_id = stream_id ,
764- control_flags = STREAM_OPEN_BIT ,
765- payload = request_serializer (request ),
766- span = span ,
767- )
765+ await self ._send_message (
766+ service_name = service_name ,
767+ procedure_name = procedure_name ,
768+ stream_id = stream_id ,
769+ control_flags = STREAM_OPEN_BIT ,
770+ payload = request_serializer (request ),
771+ span = span ,
772+ )
768773
774+ async with self ._with_stream (stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (_ , output ):
769775 # Handle potential errors during communication
770776 try :
771777 async for item in output :
@@ -811,24 +817,19 @@ async def send_stream[I, R, E, A](
811817 """
812818
813819 stream_id = nanoid .generate ()
814- async with self ._with_stream (
815- stream_id ,
816- MAX_MESSAGE_BUFFER_SIZE ,
817- ) as (backpressure_waiter , output ):
818- try :
819- await self ._send_message (
820- service_name = service_name ,
821- procedure_name = procedure_name ,
822- stream_id = stream_id ,
823- control_flags = STREAM_OPEN_BIT ,
824- payload = init_serializer (init ),
825- span = span ,
826- )
827- except Exception as e :
828- raise StreamClosedRiverServiceException (
829- ERROR_CODE_STREAM_CLOSED , str (e ), service_name , procedure_name
830- ) from e
820+ await self ._send_message (
821+ service_name = service_name ,
822+ procedure_name = procedure_name ,
823+ stream_id = stream_id ,
824+ control_flags = STREAM_OPEN_BIT ,
825+ payload = init_serializer (init ),
826+ span = span ,
827+ )
831828
829+ async with self ._with_stream (stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
830+ backpressure_waiter ,
831+ output ,
832+ ):
832833 # Create the encoder task
833834 async def _encode_stream () -> None :
834835 if not request :
0 commit comments