@@ -83,10 +83,10 @@ async def send_upload(
8383 self ,
8484 service_name : str ,
8585 procedure_name : str ,
86- init : Optional [ InitType ] ,
87- request : AsyncIterable [RequestType ],
88- init_serializer : Optional [ Callable [[InitType ], Any ] ],
89- request_serializer : Callable [[RequestType ], Any ],
86+ init : InitType ,
87+ request : Optional [ AsyncIterable [RequestType ] ],
88+ init_serializer : Callable [[InitType ], Any ],
89+ request_serializer : Optional [ Callable [[RequestType ], Any ] ],
9090 response_deserializer : Callable [[Any ], ResponseType ],
9191 error_deserializer : Callable [[Any ], ErrorType ],
9292 ) -> ResponseType :
@@ -100,29 +100,26 @@ async def send_upload(
100100 self ._streams [stream_id ] = output
101101 first_message = True
102102 try :
103- if init and init_serializer :
104- await self .send_message (
105- stream_id = stream_id ,
106- control_flags = STREAM_OPEN_BIT ,
107- service_name = service_name ,
108- procedure_name = procedure_name ,
109- payload = init_serializer (init ),
110- )
111- first_message = False
112- # If this request is not closed and the session is killed, we should
113- # throw exception here
114- async for item in request :
115- control_flags = 0
116- if first_message :
117- control_flags = STREAM_OPEN_BIT
118- first_message = False
119- await self .send_message (
120- stream_id = stream_id ,
121- service_name = service_name ,
122- procedure_name = procedure_name ,
123- control_flags = control_flags ,
124- payload = request_serializer (item ),
125- )
103+ await self .send_message (
104+ stream_id = stream_id ,
105+ control_flags = STREAM_OPEN_BIT ,
106+ service_name = service_name ,
107+ procedure_name = procedure_name ,
108+ payload = init_serializer (init ),
109+ )
110+ first_message = False
111+ if request is not None and request_serializer is not None :
112+ # If this request is not closed and the session is killed, we should
113+ # throw exception here
114+ async for item in request :
115+ control_flags = 0
116+ await self .send_message (
117+ stream_id = stream_id ,
118+ service_name = service_name ,
119+ procedure_name = procedure_name ,
120+ control_flags = control_flags ,
121+ payload = request_serializer (item ),
122+ )
126123 except Exception as e :
127124 raise RiverServiceException (
128125 ERROR_CODE_STREAM_CLOSED , str (e ), service_name , procedure_name
@@ -215,10 +212,10 @@ async def send_stream(
215212 self ,
216213 service_name : str ,
217214 procedure_name : str ,
218- init : Optional [ InitType ] ,
219- request : AsyncIterable [RequestType ],
220- init_serializer : Optional [ Callable [[InitType ], Any ] ],
221- request_serializer : Callable [[RequestType ], Any ],
215+ init : InitType ,
216+ request : Optional [ AsyncIterable [RequestType ] ],
217+ init_serializer : Callable [[InitType ], Any ],
218+ request_serializer : Optional [ Callable [[RequestType ], Any ] ],
222219 response_deserializer : Callable [[Any ], ResponseType ],
223220 error_deserializer : Callable [[Any ], ErrorType ],
224221 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
@@ -230,60 +227,36 @@ async def send_stream(
230227 stream_id = nanoid .generate ()
231228 output : Channel [Any ] = Channel (MAX_MESSAGE_BUFFER_SIZE )
232229 self ._streams [stream_id ] = output
233- empty_stream = False
234230 try :
235- if init and init_serializer :
236- await self .send_message (
237- service_name = service_name ,
238- procedure_name = procedure_name ,
239- stream_id = stream_id ,
240- control_flags = STREAM_OPEN_BIT ,
241- payload = init_serializer (init ),
242- )
243- else :
244- # Get the very first message to open the stream
245- request_iter = aiter (request )
246- first = await anext (request_iter )
247- await self .send_message (
248- service_name = service_name ,
249- procedure_name = procedure_name ,
250- stream_id = stream_id ,
251- control_flags = STREAM_OPEN_BIT ,
252- payload = request_serializer (first ),
253- )
254-
255- except StopAsyncIteration :
256- empty_stream = True
231+ await self .send_message (
232+ service_name = service_name ,
233+ procedure_name = procedure_name ,
234+ stream_id = stream_id ,
235+ control_flags = STREAM_OPEN_BIT ,
236+ payload = init_serializer (init ),
237+ )
257238
258239 except Exception as e :
259240 raise StreamClosedRiverServiceException (
260241 ERROR_CODE_STREAM_CLOSED , str (e ), service_name , procedure_name
261242 ) from e
262243
263- # Create the encoder task
264- async def _encode_stream () -> None :
265- if empty_stream :
266- await self .send_close_stream (
267- service_name ,
268- procedure_name ,
269- stream_id ,
270- extra_control_flags = STREAM_OPEN_BIT ,
271- )
272- return
273-
274- async for item in request :
275- if item is None :
276- continue
277- await self .send_message (
278- service_name = service_name ,
279- procedure_name = procedure_name ,
280- stream_id = stream_id ,
281- control_flags = 0 ,
282- payload = request_serializer (item ),
283- )
284- await self .send_close_stream (service_name , procedure_name , stream_id )
244+ if request is not None and request_serializer is not None :
245+ # Create the encoder task
246+ async def _encode_stream () -> None :
247+ async for item in request :
248+ if item is None :
249+ continue
250+ await self .send_message (
251+ service_name = service_name ,
252+ procedure_name = procedure_name ,
253+ stream_id = stream_id ,
254+ control_flags = 0 ,
255+ payload = request_serializer (item ),
256+ )
257+ await self .send_close_stream (service_name , procedure_name , stream_id )
285258
286- self ._task_manager .create_task (_encode_stream ())
259+ self ._task_manager .create_task (_encode_stream ())
287260
288261 # Handle potential errors during communication
289262 try :
0 commit comments