@@ -622,7 +622,9 @@ async def _with_stream(
622622 span : Span ,
623623 stream_id : str ,
624624 maxsize : int ,
625- ) -> AsyncIterator [tuple [_BackpressuredWaiter , AsyncIterator [ResultType ]]]:
625+ ) -> AsyncIterator [
626+ tuple [_BackpressuredWaiter , Channel [Exception ], AsyncIterator [ResultType ]]
627+ ]:
626628 """
627629 _with_stream
628630
@@ -663,7 +665,7 @@ async def error_checking_output() -> AsyncIterator[ResultType]:
663665 yield elem
664666
665667 try :
666- yield (backpressured_waiter , error_checking_output ())
668+ yield (backpressured_waiter , error_channel , error_checking_output ())
667669 finally :
668670 stream_meta = self ._streams .get (stream_id )
669671 if not stream_meta :
@@ -706,14 +708,16 @@ async def send_rpc[R, A](
706708
707709 async with self ._with_stream (span , stream_id , 1 ) as (
708710 backpressured_waiter ,
711+ error_channel ,
709712 output ,
710713 ):
711714 # Handle potential errors during communication
712715 try :
713716 async with asyncio .timeout (timeout .total_seconds ()):
714- # Block for backpressure and emission errors from the ws
717+ # Block for backpressure
715718 await backpressured_waiter ()
716- result = await anext (output )
719+ # Race output and error channels
720+ raced = await _race2 (anext (output ), error_channel .get ())
717721 except asyncio .TimeoutError as e :
718722 await self ._send_cancel_stream (
719723 stream_id = stream_id ,
@@ -728,17 +732,24 @@ async def send_rpc[R, A](
728732 service_name ,
729733 procedure_name ,
730734 ) from e
735+ except Exception as e :
736+ raise RiverException (ERROR_CODE_STREAM_CLOSED , str (e )) from e
737+ match raced :
738+ case _FinishedA (result ):
739+ if "ok" not in result or not result ["ok" ]:
740+ try :
741+ error = error_deserializer (result ["payload" ])
742+ except Exception as e :
743+ raise RiverException ("error_deserializer" , str (e )) from e
744+ raise exception_from_message (error .code )(
745+ error .code , error .message , service_name , procedure_name
746+ )
731747
732- if "ok" not in result or not result ["ok" ]:
733- try :
734- error = error_deserializer (result ["payload" ])
735- except Exception as e :
736- raise RiverException ("error_deserializer" , str (e )) from e
737- raise exception_from_message (error .code )(
738- error .code , error .message , service_name , procedure_name
739- )
740-
741- return response_deserializer (result ["payload" ])
748+ return response_deserializer (result ["payload" ])
749+ case _FinishedB (err ):
750+ raise err
751+ case other :
752+ assert_never (other )
742753
743754 async def send_upload [I , R , A ](
744755 self ,
@@ -768,6 +779,7 @@ async def send_upload[I, R, A](
768779
769780 async with self ._with_stream (span , stream_id , 1 ) as (
770781 backpressured_waiter ,
782+ error_channel ,
771783 output ,
772784 ):
773785 try :
@@ -776,6 +788,12 @@ async def send_upload[I, R, A](
776788 async for item in request :
777789 # Block for backpressure and emission errors from the ws
778790 await backpressured_waiter ()
791+ try :
792+ raise error_channel .get_nowait ()
793+ except (ChannelClosed , ChannelEmpty ):
794+ # No errors, off to the next message
795+ pass
796+
779797 try :
780798 payload = request_serializer (item )
781799 except Exception as e :
@@ -867,6 +885,7 @@ async def send_subscription[I, E, A](
867885 )
868886
869887 async with self ._with_stream (span , stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
888+ backpressured_waiter ,
870889 _ ,
871890 output ,
872891 ):
@@ -904,9 +923,12 @@ async def send_stream[I, R, E, A](
904923 error_deserializer : Callable [[Any ], E ],
905924 span : Span ,
906925 ) -> AsyncGenerator [A | E , None ]:
907- """Sends a subscription request to the server.
926+ """Sends a bidirectional stream request to the server.
908927
909- Expects the input and output be messages that will be msgpacked.
928+ When the request AsyncIterable finishes, a "CLOSE" event is sent to the server.
929+
930+ If a "CLOSE" event is received from the server, the output channel will be
931+ closed, but the requests will still be sent.
910932 """
911933
912934 stream_id = nanoid .generate ()
@@ -921,19 +943,13 @@ async def send_stream[I, R, E, A](
921943
922944 async with self ._with_stream (span , stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
923945 backpressured_waiter ,
946+ error_channel ,
924947 output ,
925948 ):
926949 # Create the encoder task
927- async def _encode_stream () -> None :
928- if not request :
929- await self ._send_close_stream (
930- stream_id = stream_id ,
931- span = span ,
932- )
933- return
934-
935- assert request_serializer , "send_stream missing request_serializer"
936-
950+ async def _encode_stream (
951+ request : AsyncIterable [R ], request_serializer : Callable [[R ], Any ]
952+ ) -> None :
937953 async for item in request :
938954 # Block for backpressure (or errors)
939955 await backpressured_waiter ()
@@ -948,23 +964,54 @@ async def _encode_stream() -> None:
948964 span = span ,
949965 )
950966
951- emitter_task = self ._task_manager .create_task (_encode_stream ())
967+ emitter_task = None
968+ if not request :
969+ await self ._send_close_stream (
970+ stream_id = stream_id ,
971+ span = span ,
972+ )
973+ else :
974+ assert request_serializer , "send_stream missing request_serializer"
975+
976+ # Now that we've validated these values, shadow them in the Task
977+ emitter_task = self ._task_manager .create_task (
978+ _encode_stream (request , request_serializer )
979+ )
952980
953981 # Handle potential errors during communication
954982 try :
955983 async for result in output :
956984 # Raise as early as we possibly can in case of an emission error
957- if emitter_task .done () and (err := emitter_task .exception ()):
985+ if (
986+ emitter_task
987+ and emitter_task .done ()
988+ and (err := emitter_task .exception ())
989+ ):
990+ raise err
991+
992+ # If the emitter channel has finished, we still need to check in
993+ # the consumer channel.
994+ try :
995+ err = error_channel .get_nowait ()
996+ await self ._send_close_stream (
997+ stream_id = stream_id ,
998+ span = span ,
999+ )
9581000 raise err
1001+ except (ChannelClosed , ChannelEmpty ):
1002+ # No errors, off to the next message
1003+ pass
1004+
9591005 if result .get ("type" ) == "CLOSE" :
9601006 break
1007+
9611008 if "ok" not in result or not result ["ok" ]:
9621009 yield error_deserializer (result ["payload" ])
9631010 continue
9641011 yield response_deserializer (result ["payload" ])
965- # ... block the outer function until the emitter is finished emitting,
966- # possibly raising a terminal exception.
967- await emitter_task
1012+ # ... block the outer function until the emitter is finished emitting.
1013+ if emitter_task :
1014+ await emitter_task
9681015 except Exception as e :
9691016 await self ._send_cancel_stream (
9701017 stream_id = stream_id ,
@@ -1356,3 +1403,32 @@ async def _recv_from_ws(
13561403 )
13571404 raise unhandled
13581405 logger .debug (f"_recv_from_ws exiting normally after { connection_attempts } loops" )
1406+
1407+
1408+ @dataclass (frozen = True )
1409+ class _FinishedA [A ]:
1410+ value : A
1411+
1412+
1413+ @dataclass (frozen = True )
1414+ class _FinishedB [B ]:
1415+ value : B
1416+
1417+
1418+ async def _race2 [A , B ](
1419+ a : Awaitable [A ],
1420+ b : Awaitable [B ],
1421+ ) -> _FinishedA [A ] | _FinishedB [B ]:
1422+ _a = asyncio .ensure_future (a )
1423+ _b = asyncio .ensure_future (b )
1424+ await asyncio .wait ([_a , _b ], return_when = asyncio .FIRST_COMPLETED )
1425+
1426+ if _a .done ():
1427+ _b .cancel ()
1428+ if err := _a .exception ():
1429+ raise err
1430+ return _FinishedA (_a .result ())
1431+ _a .cancel ()
1432+ if err := _b .exception ():
1433+ raise err
1434+ return _FinishedB (_b .result ())
0 commit comments