1- import asyncio
21import logging
32from collections .abc import Awaitable , Callable
4- from typing import Generic , assert_never
3+ from typing import Generic
54
65import nanoid
7- import websockets
8- import websockets .asyncio .client
9- from pydantic import ValidationError
10- from websockets .asyncio .client import ClientConnection
11- from websockets .exceptions import ConnectionClosed
126
137from replit_river .error_schema import (
14- ERROR_CODE_STREAM_CLOSED ,
15- ERROR_HANDSHAKE ,
168 RiverException ,
179)
18- from replit_river .messages import (
19- FailedSendingMessageException ,
20- WebsocketClosedException ,
21- parse_transport_msg ,
22- send_transport_message ,
23- )
2410from replit_river .rate_limiter import LeakyBucketRateLimit
25- from replit_river .rpc import (
26- SESSION_MISMATCH_CODE ,
27- ControlMessageHandshakeRequest ,
28- ControlMessageHandshakeResponse ,
29- ExpectedSessionState ,
30- TransportMessage ,
31- )
32- from replit_river .seq_manager import (
33- IgnoreMessageException ,
34- InvalidMessageException ,
35- )
3611from replit_river .transport_options import (
3712 HandshakeMetadataType ,
3813 TransportOptions ,
@@ -78,33 +53,25 @@ def __init__(
7853 transport_options .connection_retry_options
7954 )
8055
81- async def _close_session (self ) -> None :
82- logger .info (f"start closing session { self ._transport_id } " )
83- if not self ._session :
84- return
85- await self ._session .close ()
86- logger .info (f"Transport closed { self ._transport_id } " )
87-
88- def generate_nanoid (self ) -> str :
89- return str (nanoid .generate ())
90-
9156 async def close (self ) -> None :
9257 self ._rate_limiter .close ()
93- await self ._close_session ()
58+ if self ._session :
59+ logger .info (f"start closing session { self ._transport_id } " )
60+ await self ._session .close ()
61+ logger .info (f"Transport closed { self ._transport_id } " )
9462
9563 async def get_or_create_session (self ) -> Session :
9664 """
97- If we have an active session, return it.
98- If we have a "closed" session, mint a whole new session.
99- If we have a disconnected session, attempt to start a new WS and use it.
65+ Create a session if it does not exist,
66+ call ensure_connected on whatever session is active.
10067 """
10168 existing_session = self ._session
102- if not existing_session :
69+ if not existing_session or not existing_session . is_session_open () :
10370 logger .info ("Creating new session" )
10471 new_session = Session (
10572 transport_id = self ._transport_id ,
10673 to_id = self ._server_id ,
107- session_id = self . generate_nanoid (),
74+ session_id = nanoid . generate (),
10875 transport_options = self ._transport_options ,
10976 close_session_callback = self ._delete_session ,
11077 retry_connection_callback = self ._retry_connection ,
@@ -121,214 +88,12 @@ async def get_or_create_session(self) -> Session:
12188 )
12289 return existing_session
12390
124- async def _establish_new_connection (
125- self ,
126- old_session : Session | None = None ,
127- ) -> tuple [
128- ClientConnection ,
129- ControlMessageHandshakeRequest [HandshakeMetadataType ],
130- ControlMessageHandshakeResponse ,
131- ]:
132- """Build a new websocket connection with retry logic."""
133- rate_limit = self ._rate_limiter
134- max_retry = self ._transport_options .connection_retry_options .max_retry
135- client_id = self ._client_id
136- logger .info ("Attempting to establish new ws connection" )
137-
138- last_error : Exception | None = None
139- for i in range (max_retry ):
140- if i > 0 :
141- logger .info (f"Retrying build handshake number { i } times" )
142- if not rate_limit .has_budget (client_id ):
143- logger .debug ("No retry budget for %s." , client_id )
144- raise HandshakeBudgetExhaustedException (
145- ERROR_HANDSHAKE ,
146- "No retry budget" ,
147- client_id = client_id ,
148- ) from last_error
149-
150- rate_limit .consume_budget (client_id )
151-
152- # if the session is closed, we shouldn't use it
153- if old_session and not old_session .is_session_open ():
154- old_session = None
155-
156- try :
157- uri_and_metadata = await self ._uri_and_metadata_factory ()
158- ws = await websockets .asyncio .client .connect (uri_and_metadata ["uri" ])
159- session_id : str
160- if old_session :
161- session_id = old_session .session_id
162- else :
163- session_id = self .generate_nanoid ()
164-
165- try :
166- (
167- handshake_request ,
168- handshake_response ,
169- ) = await self ._establish_handshake (
170- session_id ,
171- uri_and_metadata ["metadata" ],
172- ws ,
173- old_session ,
174- )
175- rate_limit .start_restoring_budget (client_id )
176- return ws , handshake_request , handshake_response
177- except RiverException as e :
178- await ws .close ()
179- raise e
180- except Exception as e :
181- last_error = e
182- backoff_time = rate_limit .get_backoff_ms (client_id )
183- logger .exception (
184- f"Error connecting, retrying with { backoff_time } ms backoff"
185- )
186- await asyncio .sleep (backoff_time / 1000 )
187-
188- raise RiverException (
189- ERROR_HANDSHAKE ,
190- f"Failed to create ws after retrying { max_retry } number of times" ,
191- ) from last_error
192-
19391 async def _retry_connection (self ) -> Session :
194- if not self ._transport_options .transparent_reconnect :
195- await self ._close_session ()
92+ if not self ._transport_options .transparent_reconnect and self ._session :
93+ logger .info ("transparent_reconnect not set, closing {self._transport_id}" )
94+ await self ._session .close ()
19695 return await self .get_or_create_session ()
19796
198- async def _send_handshake_request (
199- self ,
200- session_id : str ,
201- handshake_metadata : HandshakeMetadataType | None ,
202- websocket : ClientConnection ,
203- expected_session_state : ExpectedSessionState ,
204- ) -> ControlMessageHandshakeRequest [HandshakeMetadataType ]:
205- handshake_request = ControlMessageHandshakeRequest [HandshakeMetadataType ](
206- type = "HANDSHAKE_REQ" ,
207- protocolVersion = PROTOCOL_VERSION ,
208- sessionId = session_id ,
209- metadata = handshake_metadata ,
210- expectedSessionState = expected_session_state ,
211- )
212- stream_id = self .generate_nanoid ()
213-
214- async def websocket_closed_callback () -> None :
215- logger .error ("websocket closed before handshake response" )
216-
217- try :
218- payload = handshake_request .model_dump ()
219- await send_transport_message (
220- TransportMessage (
221- from_ = self ._transport_id ,
222- to = self ._server_id ,
223- streamId = stream_id ,
224- controlFlags = 0 ,
225- id = self .generate_nanoid (),
226- seq = 0 ,
227- ack = 0 ,
228- payload = payload ,
229- ),
230- ws = websocket ,
231- websocket_closed_callback = websocket_closed_callback ,
232- )
233- return handshake_request
234- except (WebsocketClosedException , FailedSendingMessageException ) as e :
235- raise RiverException (
236- ERROR_HANDSHAKE , "Handshake failed, conn closed while sending response"
237- ) from e
238-
239- async def _get_handshake_response_msg (
240- self , websocket : ClientConnection
241- ) -> TransportMessage :
242- while True :
243- try :
244- data = await websocket .recv ()
245- except ConnectionClosed as e :
246- logger .debug (
247- "Connection closed during waiting for handshake response" ,
248- exc_info = True ,
249- )
250- raise RiverException (
251- ERROR_HANDSHAKE ,
252- "Handshake failed, conn closed while waiting for response" ,
253- ) from e
254- try :
255- return parse_transport_msg (data )
256- except IgnoreMessageException :
257- logger .debug ("Ignoring transport message" , exc_info = True )
258- continue
259- except InvalidMessageException as e :
260- raise RiverException (
261- ERROR_HANDSHAKE ,
262- "Got invalid transport message, closing connection" ,
263- ) from e
264-
265- async def _establish_handshake (
266- self ,
267- session_id : str ,
268- handshake_metadata : HandshakeMetadataType ,
269- websocket : ClientConnection ,
270- old_session : Session | None ,
271- ) -> tuple [
272- ControlMessageHandshakeRequest [HandshakeMetadataType ],
273- ControlMessageHandshakeResponse ,
274- ]:
275- try :
276- expectedSessionState : ExpectedSessionState
277- match old_session :
278- case None :
279- expectedSessionState = ExpectedSessionState (
280- nextExpectedSeq = 0 ,
281- nextSentSeq = 0 ,
282- )
283- case Session ():
284- expectedSessionState = ExpectedSessionState (
285- nextExpectedSeq = old_session .ack ,
286- nextSentSeq = old_session .seq ,
287- )
288- case other :
289- assert_never (other )
290- handshake_request = await self ._send_handshake_request (
291- session_id = session_id ,
292- handshake_metadata = handshake_metadata ,
293- websocket = websocket ,
294- expected_session_state = expectedSessionState ,
295- )
296- except FailedSendingMessageException as e :
297- raise RiverException (
298- ERROR_CODE_STREAM_CLOSED ,
299- "Stream closed before response, closing connection" ,
300- ) from e
301-
302- startup_grace_sec = 60
303- try :
304- response_msg = await asyncio .wait_for (
305- self ._get_handshake_response_msg (websocket ), startup_grace_sec
306- )
307- handshake_response = ControlMessageHandshakeResponse (** response_msg .payload )
308- logger .debug ("river client waiting for handshake response" )
309- except ValidationError as e :
310- raise RiverException (
311- ERROR_HANDSHAKE , "Failed to parse handshake response"
312- ) from e
313- except asyncio .TimeoutError as e :
314- raise RiverException (
315- ERROR_HANDSHAKE , "Handshake response timeout, closing connection"
316- ) from e
317-
318- logger .debug ("river client get handshake response : %r" , handshake_response )
319- if not handshake_response .status .ok :
320- if old_session and handshake_response .status .code == SESSION_MISMATCH_CODE :
321- # If the session status is mismatched, we should close the old session
322- # and let the retry logic to create a new session.
323- await old_session .close ()
324-
325- raise RiverException (
326- ERROR_HANDSHAKE ,
327- f"Handshake failed with code ${ handshake_response .status .code } : "
328- + f"{ handshake_response .status .reason } " ,
329- )
330- return handshake_request , handshake_response
331-
33297 async def _delete_session (self , session : Session ) -> None :
33398 if self ._session and session ._to_id == self ._session ._to_id :
33499 self ._session = None
0 commit comments