@@ -383,10 +383,10 @@ async def send_stream[I, R, E, A](
383383 self ,
384384 service_name : str ,
385385 procedure_name : str ,
386- init : I | None ,
387- request : AsyncIterable [R ],
388- init_serializer : Callable [[I ], Any ] | None ,
389- request_serializer : Callable [[R ], Any ],
386+ init : I ,
387+ request : AsyncIterable [R ] | None ,
388+ init_serializer : Callable [[I ], Any ],
389+ request_serializer : Callable [[R ], Any ] | None ,
390390 response_deserializer : Callable [[Any ], A ],
391391 error_deserializer : Callable [[Any ], E ],
392392 span : Span ,
@@ -399,41 +399,23 @@ async def send_stream[I, R, E, A](
399399 stream_id = nanoid .generate ()
400400 output : Channel [Any ] = Channel (MAX_MESSAGE_BUFFER_SIZE )
401401 self ._streams [stream_id ] = output
402- empty_stream = False
403402 try :
404- if init and init_serializer :
405- await self .send_message (
406- service_name = service_name ,
407- procedure_name = procedure_name ,
408- stream_id = stream_id ,
409- control_flags = STREAM_OPEN_BIT ,
410- payload = init_serializer (init ),
411- span = span ,
412- )
413- else :
414- # Get the very first message to open the stream
415- request_iter = aiter (request )
416- first = await anext (request_iter )
417- await self .send_message (
418- service_name = service_name ,
419- procedure_name = procedure_name ,
420- stream_id = stream_id ,
421- control_flags = STREAM_OPEN_BIT ,
422- payload = request_serializer (first ),
423- span = span ,
424- )
425-
426- except StopAsyncIteration :
427- empty_stream = True
428-
403+ await self .send_message (
404+ service_name = service_name ,
405+ procedure_name = procedure_name ,
406+ stream_id = stream_id ,
407+ control_flags = STREAM_OPEN_BIT ,
408+ payload = init_serializer (init ),
409+ span = span ,
410+ )
429411 except Exception as e :
430412 raise StreamClosedRiverServiceException (
431413 ERROR_CODE_STREAM_CLOSED , str (e ), service_name , procedure_name
432414 ) from e
433415
434416 # Create the encoder task
435417 async def _encode_stream () -> None :
436- if empty_stream :
418+ if not request :
437419 await self .send_close_stream (
438420 service_name ,
439421 procedure_name ,
@@ -442,6 +424,8 @@ async def _encode_stream() -> None:
442424 )
443425 return
444426
427+ assert request_serializer , "send_stream missing request_serializer"
428+
445429 async for item in request :
446430 if item is None :
447431 continue
0 commit comments