7171)
7272from replit_river .v2 .client_transport import (
7373 PROTOCOL_VERSION ,
74- HandshakeBudgetExhaustedException ,
7574)
7675
7776STREAM_CANCEL_BIT_TYPE = Literal [0b00100 ]
@@ -202,11 +201,13 @@ def get_next_pending() -> TransportMessage | None:
202201 )
203202 )
204203
205- async def ensure_connected (
204+ async def ensure_connected [ HandshakeMetadata ] (
206205 self ,
207206 client_id : str ,
208207 rate_limiter : LeakyBucketRateLimit ,
209- uri_and_metadata_factory : Callable [[], Awaitable [UriAndMetadata ]],
208+ uri_and_metadata_factory : Callable [
209+ [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
210+ ], # noqa: E501
210211 ) -> None :
211212 """
212213 Either return immediately or establish a websocket connection and return
@@ -218,16 +219,11 @@ async def ensure_connected(
218219 logger .info ("Attempting to establish new ws connection" )
219220
220221 last_error : Exception | None = None
221- for i in range (max_retry ):
222+ i = 0
223+ while rate_limiter .has_budget_or_throw (client_id , ERROR_HANDSHAKE , last_error ):
222224 if i > 0 :
223225 logger .info (f"Retrying build handshake number { i } times" )
224- if not rate_limiter .has_budget (client_id ):
225- logger .debug ("No retry budget for %s." , client_id )
226- raise HandshakeBudgetExhaustedException (
227- ERROR_HANDSHAKE ,
228- "No retry budget" ,
229- client_id = client_id ,
230- ) from last_error
226+ i += 1
231227
232228 rate_limiter .consume_budget (client_id )
233229
@@ -238,10 +234,12 @@ async def ensure_connected(
238234 try :
239235 try :
240236 expectedSessionState = ExpectedSessionState (
241- nextExpectedSeq = 0 ,
242- nextSentSeq = 0 ,
237+ nextExpectedSeq = self . ack ,
238+ nextSentSeq = self . seq ,
243239 )
244- handshake_request = ControlMessageHandshakeRequest [Any ](
240+ handshake_request = ControlMessageHandshakeRequest [
241+ HandshakeMetadata
242+ ]( # noqa: E501
245243 type = "HANDSHAKE_REQ" ,
246244 protocolVersion = PROTOCOL_VERSION ,
247245 sessionId = self .session_id ,
@@ -253,85 +251,68 @@ async def ensure_connected(
253251 async def websocket_closed_callback () -> None :
254252 logger .error ("websocket closed before handshake response" )
255253
254+ await send_transport_message (
255+ TransportMessage (
256+ from_ = self ._transport_id ,
257+ to = self ._to_id ,
258+ streamId = stream_id ,
259+ controlFlags = 0 ,
260+ id = nanoid .generate (),
261+ seq = 0 ,
262+ ack = 0 ,
263+ payload = handshake_request .model_dump (),
264+ ),
265+ ws = ws ,
266+ websocket_closed_callback = websocket_closed_callback ,
267+ )
268+ except (
269+ WebsocketClosedException ,
270+ FailedSendingMessageException ,
271+ ) as e : # noqa: E501
272+ raise RiverException (
273+ ERROR_HANDSHAKE ,
274+ "Handshake failed, conn closed while sending response" , # noqa: E501
275+ ) from e
276+
277+ startup_grace_deadline_ms = await self ._get_current_time () + 60_000
278+ while True :
279+ if await self ._get_current_time () >= startup_grace_deadline_ms : # noqa: E501
280+ raise RiverException (
281+ ERROR_HANDSHAKE ,
282+ "Handshake response timeout, closing connection" , # noqa: E501
283+ )
256284 try :
257- payload = handshake_request .model_dump ()
258- await send_transport_message (
259- TransportMessage (
260- from_ = self ._transport_id ,
261- to = self ._to_id ,
262- streamId = stream_id ,
263- controlFlags = 0 ,
264- id = nanoid .generate (),
265- seq = 0 ,
266- ack = 0 ,
267- payload = payload ,
268- ),
269- ws = ws ,
270- websocket_closed_callback = websocket_closed_callback ,
285+ data = await ws .recv ()
286+ except ConnectionClosed as e :
287+ logger .debug (
288+ "Connection closed during waiting for handshake response" , # noqa: E501
289+ exc_info = True ,
271290 )
272- except (
273- WebsocketClosedException ,
274- FailedSendingMessageException ,
275- ) as e : # noqa: E501
276291 raise RiverException (
277292 ERROR_HANDSHAKE ,
278- "Handshake failed, conn closed while sending response" , # noqa: E501
293+ "Handshake failed, conn closed while waiting for response" , # noqa: E501
294+ ) from e
295+ try :
296+ response_msg = parse_transport_msg (data )
297+ break
298+ except IgnoreMessageException :
299+ logger .debug ("Ignoring transport message" , exc_info = True ) # noqa: E501
300+ continue
301+ except InvalidMessageException as e :
302+ raise RiverException (
303+ ERROR_HANDSHAKE ,
304+ "Got invalid transport message, closing connection" ,
279305 ) from e
280- except FailedSendingMessageException as e :
281- raise RiverException (
282- ERROR_CODE_STREAM_CLOSED ,
283- "Stream closed before response, closing connection" ,
284- ) from e
285306
286- startup_grace_deadline_ms = await self ._get_current_time () + 60_000
287307 try :
288- while True :
289- if (
290- await self ._get_current_time ()
291- >= startup_grace_deadline_ms
292- ): # noqa: E501
293- raise RiverException (
294- ERROR_HANDSHAKE ,
295- "Handshake response timeout, closing connection" , # noqa: E501
296- )
297- try :
298- data = await ws .recv ()
299- except ConnectionClosed as e :
300- logger .debug (
301- "Connection closed during waiting for handshake response" , # noqa: E501
302- exc_info = True ,
303- )
304- raise RiverException (
305- ERROR_HANDSHAKE ,
306- "Handshake failed, conn closed while waiting for response" , # noqa: E501
307- ) from e
308- try :
309- response_msg = parse_transport_msg (data )
310- break
311- except IgnoreMessageException :
312- logger .debug (
313- "Ignoring transport message" , exc_info = True
314- ) # noqa: E501
315- continue
316- except InvalidMessageException as e :
317- raise RiverException (
318- ERROR_HANDSHAKE ,
319- "Got invalid transport message, closing connection" ,
320- ) from e
321-
322308 handshake_response = ControlMessageHandshakeResponse (
323309 ** response_msg .payload
324- ) # noqa: E501
310+ )
325311 logger .debug ("river client waiting for handshake response" )
326312 except ValidationError as e :
327313 raise RiverException (
328314 ERROR_HANDSHAKE , "Failed to parse handshake response"
329315 ) from e
330- except asyncio .TimeoutError as e :
331- raise RiverException (
332- ERROR_HANDSHAKE ,
333- "Handshake response timeout, closing connection" , # noqa: E501
334- ) from e
335316
336317 logger .debug (
337318 "river client get handshake response : %r" , handshake_response
0 commit comments