2121
2222import nanoid
2323import websockets .asyncio .client
24- from aiochannel import Channel , ChannelFull
24+ from aiochannel import Channel , ChannelEmpty , ChannelFull
2525from aiochannel .errors import ChannelClosed
2626from opentelemetry .trace import Span , use_span
2727from opentelemetry .trace .propagation .tracecontext import TraceContextTextMapPropagator
8383STREAM_CLOSED_BIT : STREAM_CLOSED_BIT_TYPE = 0b01000
8484
8585
86+ _BackpressuredWaiter : TypeAlias = Callable [[], Awaitable [None ]]
87+
88+
8689class ResultOk (TypedDict ):
8790 ok : Literal [True ]
8891 payload : Any
@@ -120,7 +123,8 @@ class _IgnoreMessage:
120123
121124class StreamMeta (TypedDict ):
122125 span : Span
123- error_channel : Channel [None | Exception ]
126+ release_backpressured_waiter : Callable [[], None ]
127+ error_channel : Channel [Exception ]
124128 output : Channel [Any ]
125129
126130
@@ -417,6 +421,7 @@ async def close(self, reason: Exception | None = None) -> None:
417421 logger .exception (
418422 "Unable to tell the caller that the session is going away" ,
419423 )
424+ stream_meta ["release_backpressured_waiter" ]()
420425 # Before we GC the streams, let's wait for all tasks to be closed gracefully.
421426 await asyncio .gather (
422427 * [stream_meta ["output" ].join () for stream_meta in self ._streams .values ()]
@@ -473,7 +478,7 @@ async def commit(msg: TransportMessage) -> None:
473478 # Wake up backpressured writer
474479 stream_meta = self ._streams .get (pending .streamId )
475480 if stream_meta :
476- await stream_meta ["error_channel" ]. put ( None )
481+ stream_meta ["release_backpressured_waiter" ]( )
477482
478483 def get_next_pending () -> TransportMessage | None :
479484 if self ._send_buffer :
@@ -584,7 +589,7 @@ async def _with_stream(
584589 span : Span ,
585590 stream_id : str ,
586591 maxsize : int ,
587- ) -> AsyncIterator [tuple [Channel [ None | Exception ], Channel [ResultType ]]]:
592+ ) -> AsyncIterator [tuple [_BackpressuredWaiter , AsyncIterator [ResultType ]]]:
588593 """
589594 _with_stream
590595
@@ -596,14 +601,36 @@ async def _with_stream(
596601 emitted should call await error_channel.wait() prior to emission.
597602 """
598603 output : Channel [Any ] = Channel (maxsize = maxsize )
599- error_channel : Channel [None | Exception ] = Channel (maxsize = 1 )
604+ backpressured_waiter_event : asyncio .Event = asyncio .Event ()
605+ error_channel : Channel [Exception ] = Channel (maxsize = 1 )
600606 self ._streams [stream_id ] = {
601607 "span" : span ,
602608 "error_channel" : error_channel ,
609+ "release_backpressured_waiter" : backpressured_waiter_event .set ,
603610 "output" : output ,
604611 }
612+
613+ async def backpressured_waiter () -> None :
614+ await backpressured_waiter_event .wait ()
615+ try :
616+ err = error_channel .get_nowait ()
617+ raise err
618+ except (ChannelClosed , ChannelEmpty ):
619+ # No errors, off to the next message
620+ pass
621+
622+ async def error_checking_output () -> AsyncIterator [ResultType ]:
623+ async for elem in output :
624+ try :
625+ err = error_channel .get_nowait ()
626+ raise err
627+ except (ChannelClosed , ChannelEmpty ):
628+ # No errors, off to the next message
629+ pass
630+ yield elem
631+
605632 try :
606- yield (error_channel , output )
633+ yield (backpressured_waiter , error_checking_output () )
607634 finally :
608635 stream_meta = self ._streams .get (stream_id )
609636 if not stream_meta :
@@ -644,14 +671,16 @@ async def send_rpc[R, A](
644671 span = span ,
645672 )
646673
647- async with self ._with_stream (span , stream_id , 1 ) as (error_channel , output ):
674+ async with self ._with_stream (span , stream_id , 1 ) as (
675+ backpressured_waiter ,
676+ output ,
677+ ):
648678 # Handle potential errors during communication
649679 try :
650680 async with asyncio .timeout (timeout .total_seconds ()):
651681 # Block for backpressure and emission errors from the ws
652- if err := await error_channel .get ():
653- raise err
654- result = await output .get ()
682+ await backpressured_waiter ()
683+ result = await anext (output )
655684 except asyncio .TimeoutError as e :
656685 await self ._send_cancel_stream (
657686 stream_id = stream_id ,
@@ -705,15 +734,16 @@ async def send_upload[I, R, A](
705734 span = span ,
706735 )
707736
708- async with self ._with_stream (span , stream_id , 1 ) as (error_channel , output ):
737+ async with self ._with_stream (span , stream_id , 1 ) as (
738+ backpressured_waiter ,
739+ output ,
740+ ):
709741 try :
710742 # If this request is not closed and the session is killed, we should
711743 # throw exception here
712744 async for item in request :
713745 # Block for backpressure and emission errors from the ws
714- if err := await error_channel .get ():
715- raise err
716-
746+ await backpressured_waiter ()
717747 try :
718748 payload = request_serializer (item )
719749 except Exception as e :
@@ -753,7 +783,7 @@ async def send_upload[I, R, A](
753783 )
754784
755785 try :
756- result = await output . get ( )
786+ result = await anext ( output )
757787 except ChannelClosed as e :
758788 raise RiverServiceException (
759789 ERROR_CODE_STREAM_CLOSED ,
@@ -856,7 +886,7 @@ async def send_stream[I, R, E, A](
856886 )
857887
858888 async with self ._with_stream (span , stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
859- error_channel ,
889+ backpressured_waiter ,
860890 output ,
861891 ):
862892 # Create the encoder task
@@ -871,13 +901,9 @@ async def _encode_stream() -> None:
871901 assert request_serializer , "send_stream missing request_serializer"
872902
873903 async for item in request :
874- # Block for backpressure and emission errors from the ws
875- if err := await error_channel .get ():
876- await self ._send_close_stream (
877- stream_id = stream_id ,
878- span = span ,
879- )
880- raise err
904+ # Block for backpressure (or errors)
905+ await backpressured_waiter ()
906+ # If there are any errors so far, raise them
881907 await self ._enqueue_message (
882908 stream_id = stream_id ,
883909 control_flags = 0 ,
@@ -894,14 +920,15 @@ async def _encode_stream() -> None:
894920 try :
895921 async for result in output :
896922 # Raise as early as we possibly can in case of an emission error
897- if err := emitter_task .done () and emitter_task .exception ():
923+ if emitter_task .done () and ( err := emitter_task .exception () ):
898924 raise err
899925 if result .get ("type" ) == "CLOSE" :
900926 break
901927 if "ok" not in result or not result ["ok" ]:
902928 yield error_deserializer (result ["payload" ])
903929 yield response_deserializer (result ["payload" ])
904- # ... block the outer function until the emitter is finished emitting.
930+ # ... block the outer function until the emitter is finished emitting,
931+ # possibly raising a terminal exception.
905932 await emitter_task
906933 except Exception as e :
907934 await self ._send_cancel_stream (
0 commit comments