1515)
1616
1717import nanoid # type: ignore
18+ import websockets .asyncio .client
1819from aiochannel import Channel
1920from aiochannel .errors import ChannelClosed
2021from opentelemetry .trace import Span , use_span
2122from opentelemetry .trace .propagation .tracecontext import TraceContextTextMapPropagator
23+ from pydantic import ValidationError
2224from websockets .asyncio .client import ClientConnection
2325from websockets .exceptions import ConnectionClosed , ConnectionClosedOK
2426from websockets .frames import CloseCode
3335from replit_river .error_schema import (
3436 ERROR_CODE_CANCEL ,
3537 ERROR_CODE_STREAM_CLOSED ,
38+ ERROR_HANDSHAKE ,
3639 RiverError ,
3740 RiverException ,
3841 RiverServiceException ,
4144)
4245from replit_river .messages import (
4346 FailedSendingMessageException ,
47+ WebsocketClosedException ,
4448 parse_transport_msg ,
49+ send_transport_message ,
4550)
51+ from replit_river .rate_limiter import LeakyBucketRateLimit
4652from replit_river .rpc import (
4753 ACK_BIT ,
4854 STREAM_OPEN_BIT ,
55+ ControlMessageHandshakeRequest ,
56+ ControlMessageHandshakeResponse ,
57+ ExpectedSessionState ,
4958 TransportMessage ,
5059 TransportMessageTracingSetter ,
5160)
5564 OutOfOrderMessageException ,
5665)
5766from replit_river .task_manager import BackgroundTaskManager
58- from replit_river .transport_options import MAX_MESSAGE_BUFFER_SIZE , TransportOptions
67+ from replit_river .transport_options import (
68+ MAX_MESSAGE_BUFFER_SIZE ,
69+ TransportOptions ,
70+ UriAndMetadata ,
71+ )
72+ from replit_river .v2 .client_transport import (
73+ PROTOCOL_VERSION ,
74+ HandshakeBudgetExhaustedException ,
75+ )
5976
6077STREAM_CANCEL_BIT_TYPE = Literal [0b00100 ]
6178STREAM_CANCEL_BIT : STREAM_CANCEL_BIT_TYPE = 0b00100
@@ -107,7 +124,6 @@ def __init__(
107124 transport_id : str ,
108125 to_id : str ,
109126 session_id : str ,
110- websocket : ClientConnection ,
111127 transport_options : TransportOptions ,
112128 close_session_callback : CloseSessionCallback ,
113129 retry_connection_callback : RetryConnectionCallback | None = None ,
@@ -123,8 +139,7 @@ def __init__(
123139 self ._close_session_after_time_secs : float | None = None
124140
125141 # ws state
126- self ._ws_connected = True
127- self ._ws_unwrapped = websocket
142+ self ._ws_connected = False
128143 self ._heartbeat_misses = 0
129144 self ._retry_connection_callback = retry_connection_callback
130145
@@ -187,6 +202,164 @@ def get_next_pending() -> TransportMessage | None:
187202 )
188203 )
189204
205+ async def ensure_connected (
206+ self ,
207+ client_id : str ,
208+ rate_limiter : LeakyBucketRateLimit ,
209+ uri_and_metadata_factory : Callable [[], Awaitable [UriAndMetadata ]],
210+ ) -> None :
211+ """
212+ Either return immediately or establish a websocket connection and return
213+ once we can accept messages
214+ """
215+ if self ._ws_unwrapped and self ._ws_connected :
216+ return
217+ max_retry = self ._transport_options .connection_retry_options .max_retry
218+ logger .info ("Attempting to establish new ws connection" )
219+
220+ last_error : Exception | None = None
221+ for i in range (max_retry ):
222+ if i > 0 :
223+ logger .info (f"Retrying build handshake number { i } times" )
224+ if not rate_limiter .has_budget (client_id ):
225+ logger .debug ("No retry budget for %s." , client_id )
226+ raise HandshakeBudgetExhaustedException (
227+ ERROR_HANDSHAKE ,
228+ "No retry budget" ,
229+ client_id = client_id ,
230+ ) from last_error
231+
232+ rate_limiter .consume_budget (client_id )
233+
234+ try :
235+ uri_and_metadata = await uri_and_metadata_factory ()
236+ ws = await websockets .asyncio .client .connect (uri_and_metadata ["uri" ])
237+
238+ try :
239+ try :
240+ expectedSessionState = ExpectedSessionState (
241+ nextExpectedSeq = 0 ,
242+ nextSentSeq = 0 ,
243+ )
244+ handshake_request = ControlMessageHandshakeRequest [Any ](
245+ type = "HANDSHAKE_REQ" ,
246+ protocolVersion = PROTOCOL_VERSION ,
247+ sessionId = self .session_id ,
248+ metadata = uri_and_metadata ["metadata" ],
249+ expectedSessionState = expectedSessionState ,
250+ )
251+ stream_id = nanoid .generate ()
252+
253+ async def websocket_closed_callback () -> None :
254+ logger .error ("websocket closed before handshake response" )
255+
256+ try :
257+ payload = handshake_request .model_dump ()
258+ await send_transport_message (
259+ TransportMessage (
260+ from_ = self ._transport_id ,
261+ to = self ._to_id ,
262+ streamId = stream_id ,
263+ controlFlags = 0 ,
264+ id = nanoid .generate (),
265+ seq = 0 ,
266+ ack = 0 ,
267+ payload = payload ,
268+ ),
269+ ws = ws ,
270+ websocket_closed_callback = websocket_closed_callback ,
271+ )
272+ except (
273+ WebsocketClosedException ,
274+ FailedSendingMessageException ,
275+ ) as e : # noqa: E501
276+ raise RiverException (
277+ ERROR_HANDSHAKE ,
278+ "Handshake failed, conn closed while sending response" , # noqa: E501
279+ ) from e
280+ except FailedSendingMessageException as e :
281+ raise RiverException (
282+ ERROR_CODE_STREAM_CLOSED ,
283+ "Stream closed before response, closing connection" ,
284+ ) from e
285+
286+ startup_grace_deadline_ms = await self ._get_current_time () + 60_000
287+ try :
288+ while True :
289+ if (
290+ await self ._get_current_time ()
291+ >= startup_grace_deadline_ms
292+ ): # noqa: E501
293+ raise RiverException (
294+ ERROR_HANDSHAKE ,
295+ "Handshake response timeout, closing connection" , # noqa: E501
296+ )
297+ try :
298+ data = await ws .recv ()
299+ except ConnectionClosed as e :
300+ logger .debug (
301+ "Connection closed during waiting for handshake response" , # noqa: E501
302+ exc_info = True ,
303+ )
304+ raise RiverException (
305+ ERROR_HANDSHAKE ,
306+ "Handshake failed, conn closed while waiting for response" , # noqa: E501
307+ ) from e
308+ try :
309+ response_msg = parse_transport_msg (data )
310+ break
311+ except IgnoreMessageException :
312+ logger .debug (
313+ "Ignoring transport message" , exc_info = True
314+ ) # noqa: E501
315+ continue
316+ except InvalidMessageException as e :
317+ raise RiverException (
318+ ERROR_HANDSHAKE ,
319+ "Got invalid transport message, closing connection" ,
320+ ) from e
321+
322+ handshake_response = ControlMessageHandshakeResponse (
323+ ** response_msg .payload
324+ ) # noqa: E501
325+ logger .debug ("river client waiting for handshake response" )
326+ except ValidationError as e :
327+ raise RiverException (
328+ ERROR_HANDSHAKE , "Failed to parse handshake response"
329+ ) from e
330+ except asyncio .TimeoutError as e :
331+ raise RiverException (
332+ ERROR_HANDSHAKE ,
333+ "Handshake response timeout, closing connection" , # noqa: E501
334+ ) from e
335+
336+ logger .debug (
337+ "river client get handshake response : %r" , handshake_response
338+ ) # noqa: E501
339+ if not handshake_response .status .ok :
340+ raise RiverException (
341+ ERROR_HANDSHAKE ,
342+ f"Handshake failed with code { handshake_response .status .code } : " # noqa: E501
343+ + f"{ handshake_response .status .reason } " ,
344+ )
345+
346+ rate_limiter .start_restoring_budget (client_id )
347+ except RiverException as e :
348+ await ws .close ()
349+ raise e
350+ except Exception as e :
351+ last_error = e
352+ backoff_time = rate_limiter .get_backoff_ms (client_id )
353+ logger .exception (
354+ f"Error connecting, retrying with { backoff_time } ms backoff"
355+ )
356+ await asyncio .sleep (backoff_time / 1000 )
357+
358+ raise RiverException (
359+ ERROR_HANDSHAKE ,
360+ f"Failed to create ws after retrying { max_retry } number of times" ,
361+ ) from last_error
362+
190363 def _setup_heartbeats_task (
191364 self ,
192365 do_close_websocket : Callable [[], Awaitable [None ]],
0 commit comments