@@ -244,10 +244,10 @@ async def send_upload[I, R, A](
244244 self ,
245245 service_name : str ,
246246 procedure_name : str ,
247- init : I | None ,
248- request : AsyncIterable [R ],
249- init_serializer : Callable [[I ], Any ] | None ,
250- request_serializer : Callable [[R ], Any ],
247+ init : I ,
248+ request : AsyncIterable [R ] | None ,
249+ init_serializer : Callable [[I ], Any ],
250+ request_serializer : Callable [[R ], Any ] | None ,
251251 response_deserializer : Callable [[Any ], A ],
252252 error_deserializer : Callable [[Any ], RiverError ],
253253 span : Span ,
@@ -260,33 +260,31 @@ async def send_upload[I, R, A](
260260 stream_id = nanoid .generate ()
261261 output : Channel [Any ] = Channel (1 )
262262 self ._streams [stream_id ] = output
263- first_message = True
264263 try :
265- if init and init_serializer :
266- await self .send_message (
267- stream_id = stream_id ,
268- control_flags = STREAM_OPEN_BIT ,
269- service_name = service_name ,
270- procedure_name = procedure_name ,
271- payload = init_serializer (init ),
272- span = span ,
273- )
274- first_message = False
275- # If this request is not closed and the session is killed, we should
276- # throw exception here
277- async for item in request :
278- control_flags = 0
279- if first_message :
280- control_flags = STREAM_OPEN_BIT
281- first_message = False
282- await self .send_message (
283- stream_id = stream_id ,
284- service_name = service_name ,
285- procedure_name = procedure_name ,
286- control_flags = control_flags ,
287- payload = request_serializer (item ),
288- span = span ,
289- )
264+ await self .send_message (
265+ stream_id = stream_id ,
266+ control_flags = STREAM_OPEN_BIT ,
267+ service_name = service_name ,
268+ procedure_name = procedure_name ,
269+ payload = init_serializer (init ),
270+ span = span ,
271+ )
272+
273+ if request :
274+ assert request_serializer , "send_stream missing request_serializer"
275+
276+ # If this request is not closed and the session is killed, we should
277+ # throw exception here
278+ async for item in request :
279+ control_flags = 0
280+ await self .send_message (
281+ stream_id = stream_id ,
282+ service_name = service_name ,
283+ procedure_name = procedure_name ,
284+ control_flags = control_flags ,
285+ payload = request_serializer (item ),
286+ span = span ,
287+ )
290288 except Exception as e :
291289 raise RiverServiceException (
292290 ERROR_CODE_STREAM_CLOSED , str (e ), service_name , procedure_name
@@ -295,7 +293,7 @@ async def send_upload[I, R, A](
295293 service_name ,
296294 procedure_name ,
297295 stream_id ,
298- extra_control_flags = STREAM_OPEN_BIT if first_message else 0 ,
296+ extra_control_flags = 0 ,
299297 )
300298
301299 # Handle potential errors during communication
0 commit comments