@@ -222,120 +222,120 @@ async def _do_ensure_connected[HandshakeMetadata](
222222
223223 rate_limiter .consume_budget (client_id )
224224
225+ ws = None
225226 try :
226227 uri_and_metadata = await uri_and_metadata_factory ()
227228 ws = await websockets .asyncio .client .connect (uri_and_metadata ["uri" ])
228229
229230 try :
230- try :
231- next_seq = 0
232- if self ._send_buffer :
233- next_seq = self ._send_buffer [0 ].seq
234- handshake_request = ControlMessageHandshakeRequest [
235- HandshakeMetadata
236- ]( # noqa: E501
237- type = "HANDSHAKE_REQ" ,
238- protocolVersion = protocol_version ,
239- sessionId = self .session_id ,
240- metadata = uri_and_metadata ["metadata" ],
241- expectedSessionState = ExpectedSessionState (
242- nextExpectedSeq = self .ack ,
243- nextSentSeq = next_seq ,
244- ),
231+ next_seq = 0
232+ if self ._send_buffer :
233+ next_seq = self ._send_buffer [0 ].seq
234+ handshake_request = ControlMessageHandshakeRequest [
235+ HandshakeMetadata
236+ ]( # noqa: E501
237+ type = "HANDSHAKE_REQ" ,
238+ protocolVersion = protocol_version ,
239+ sessionId = self .session_id ,
240+ metadata = uri_and_metadata ["metadata" ],
241+ expectedSessionState = ExpectedSessionState (
242+ nextExpectedSeq = self .ack ,
243+ nextSentSeq = next_seq ,
244+ ),
245+ )
246+ stream_id = nanoid .generate ()
247+
248+ async def websocket_closed_callback () -> None :
249+ logger .error ("websocket closed before handshake response" )
250+
251+ await send_transport_message (
252+ TransportMessage (
253+ from_ = self ._transport_id ,
254+ to = self ._to_id ,
255+ streamId = stream_id ,
256+ controlFlags = 0 ,
257+ id = nanoid .generate (),
258+ seq = 0 ,
259+ ack = 0 ,
260+ payload = handshake_request .model_dump (),
261+ ),
262+ ws = ws ,
263+ websocket_closed_callback = websocket_closed_callback ,
264+ )
265+ except (
266+ WebsocketClosedException ,
267+ FailedSendingMessageException ,
268+ ) as e : # noqa: E501
269+ raise RiverException (
270+ ERROR_HANDSHAKE ,
271+ "Handshake failed, conn closed while sending response" , # noqa: E501
272+ ) from e
273+
274+ startup_grace_deadline_ms = await self ._get_current_time () + 60_000
275+ while True :
276+ if await self ._get_current_time () >= startup_grace_deadline_ms : # noqa: E501
277+ raise RiverException (
278+ ERROR_HANDSHAKE ,
279+ "Handshake response timeout, closing connection" , # noqa: E501
245280 )
246- stream_id = nanoid .generate ()
247-
248- async def websocket_closed_callback () -> None :
249- logger .error ("websocket closed before handshake response" )
250-
251- await send_transport_message (
252- TransportMessage (
253- from_ = self ._transport_id ,
254- to = self ._to_id ,
255- streamId = stream_id ,
256- controlFlags = 0 ,
257- id = nanoid .generate (),
258- seq = 0 ,
259- ack = 0 ,
260- payload = handshake_request .model_dump (),
261- ),
262- ws = ws ,
263- websocket_closed_callback = websocket_closed_callback ,
281+ try :
282+ data = await ws .recv (decode = False )
283+ except ConnectionClosed as e :
284+ logger .debug (
285+ "Connection closed during waiting for handshake response" , # noqa: E501
286+ exc_info = True ,
264287 )
265- except (
266- WebsocketClosedException ,
267- FailedSendingMessageException ,
268- ) as e : # noqa: E501
269288 raise RiverException (
270289 ERROR_HANDSHAKE ,
271- "Handshake failed, conn closed while sending response" , # noqa: E501
290+ "Handshake failed, conn closed while waiting for response" , # noqa: E501
272291 ) from e
273292
274- startup_grace_deadline_ms = await self ._get_current_time () + 60_000
275- while True :
276- if await self ._get_current_time () >= startup_grace_deadline_ms : # noqa: E501
277- raise RiverException (
278- ERROR_HANDSHAKE ,
279- "Handshake response timeout, closing connection" , # noqa: E501
280- )
281- try :
282- data = await ws .recv ()
283- except ConnectionClosed as e :
284- logger .debug (
285- "Connection closed during waiting for handshake response" , # noqa: E501
286- exc_info = True ,
287- )
288- raise RiverException (
289- ERROR_HANDSHAKE ,
290- "Handshake failed, conn closed while waiting for response" , # noqa: E501
291- ) from e
292- try :
293- response_msg = parse_transport_msg (data )
294- break
295- except IgnoreMessageException :
296- logger .debug ("Ignoring transport message" , exc_info = True ) # noqa: E501
297- continue
298- except InvalidMessageException as e :
299- raise RiverException (
300- ERROR_HANDSHAKE ,
301- "Got invalid transport message, closing connection" ,
302- ) from e
303-
304293 try :
305- handshake_response = ControlMessageHandshakeResponse (
306- ** response_msg .payload
307- )
308- logger .debug ("river client waiting for handshake response" )
309- except ValidationError as e :
294+ response_msg = parse_transport_msg (data )
295+ break
296+ except IgnoreMessageException :
297+ logger .debug ("Ignoring transport message" , exc_info = True ) # noqa: E501
298+ continue
299+ except InvalidMessageException as e :
310300 raise RiverException (
311- ERROR_HANDSHAKE , "Failed to parse handshake response"
301+ ERROR_HANDSHAKE ,
302+ "Got invalid transport message, closing connection" ,
312303 ) from e
313304
314- logger .debug (
315- "river client get handshake response : %r" , handshake_response
316- ) # noqa: E501
317- if not handshake_response .status .ok :
318- if (
319- handshake_response .status .code
320- == ERROR_CODE_SESSION_STATE_MISMATCH
321- ): # noqa: E501
322- await self .close ()
323- raise RiverException (
324- ERROR_HANDSHAKE ,
325- f"Handshake failed with code { handshake_response .status .code } : " # noqa: E501
326- f"{ handshake_response .status .reason } " ,
327- )
305+ try :
306+ handshake_response = ControlMessageHandshakeResponse (
307+ ** response_msg .payload
308+ )
309+ logger .debug ("river client waiting for handshake response" )
310+ except ValidationError as e :
311+ raise RiverException (
312+ ERROR_HANDSHAKE , "Failed to parse handshake response"
313+ ) from e
328314
329- last_error = None
330- rate_limiter .start_restoring_budget (client_id )
331- self ._state = SessionState .ACTIVE
332- self ._ws_unwrapped = ws
333- self ._connection_condition .notify_all ()
334- break
335- except RiverException as e :
336- await ws .close ()
337- raise e
315+ logger .debug (
316+ "river client get handshake response : %r" , handshake_response
317+ ) # noqa: E501
318+ if not handshake_response .status .ok :
319+ if (
320+ handshake_response .status .code
321+ == ERROR_CODE_SESSION_STATE_MISMATCH
322+ ): # noqa: E501
323+ await self .close ()
324+ raise RiverException (
325+ ERROR_HANDSHAKE ,
326+ f"Handshake failed with code { handshake_response .status .code } : " # noqa: E501
327+ f"{ handshake_response .status .reason } " ,
328+ )
329+
330+ last_error = None
331+ rate_limiter .start_restoring_budget (client_id )
332+ self ._state = SessionState .ACTIVE
333+ self ._ws_unwrapped = ws
334+ self ._connection_condition .notify_all ()
335+ break
338336 except Exception as e :
337+ if ws :
338+ await ws .close ()
339339 last_error = e
340340 backoff_time = rate_limiter .get_backoff_ms (client_id )
341341 logger .exception (
0 commit comments