Skip to content

Commit 6d99316

Browse files
Inline ws creation
1 parent f0de172 commit 6d99316

File tree

3 files changed

+200
-54
lines changed

3 files changed

+200
-54
lines changed

src/replit_river/client_transport.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ async def get_or_create_session(self) -> ClientSession:
117117
return existing_session
118118
else:
119119
logger.info("Closing stale session %s", existing_session.session_id)
120+
await new_ws.close() # NB(dstewart): This wasn't there in the
121+
# v1 transport, were we just leaking WS?
120122
await existing_session.close()
121123
return await self._create_new_session()
122124

src/replit_river/v2/client_transport.py

Lines changed: 21 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from replit_river.error_schema import (
1414
ERROR_CODE_STREAM_CLOSED,
1515
ERROR_HANDSHAKE,
16-
ERROR_SESSION,
1716
RiverException,
1817
)
1918
from replit_river.messages import (
@@ -96,31 +95,28 @@ async def get_or_create_session(self) -> Session:
9695
If we have a disconnected session, attempt to start a new WS and use it.
9796
"""
9897
async with self._create_session_lock:
99-
existing_session = (
100-
self._session
101-
if self._session and self._session.is_session_open()
102-
else None
103-
)
104-
if existing_session is None:
105-
return await self._create_new_session()
106-
if existing_session.is_websocket_open():
107-
return existing_session
108-
new_ws, _, hs_response = await self._establish_new_connection(
109-
existing_session
110-
)
111-
if hs_response.status.sessionId == existing_session.session_id:
112-
logger.info(
113-
"Replacing ws connection in session id %s",
114-
existing_session.session_id,
98+
existing_session = self._session
99+
if not existing_session:
100+
logger.info("Creating new session")
101+
new_session = Session(
102+
transport_id=self._transport_id,
103+
to_id=self._server_id,
104+
session_id=self.generate_nanoid(),
105+
transport_options=self._transport_options,
106+
close_session_callback=self._delete_session,
107+
retry_connection_callback=self._retry_connection,
115108
)
116-
await existing_session.replace_with_new_websocket(new_ws)
117-
return existing_session
118-
else:
119-
logger.info("Closing stale session %s", existing_session.session_id)
120-
await new_ws.close() # NB(dstewart): This wasn't there in the
121-
# v1 transport, were we just leaking WS?
122-
await existing_session.close()
123-
return await self._create_new_session()
109+
110+
self._session = new_session
111+
existing_session = new_session
112+
await existing_session.start_serve_responses()
113+
114+
await existing_session.ensure_connected(
115+
client_id=self._client_id,
116+
rate_limiter=self._rate_limiter,
117+
uri_and_metadata_factory=self._uri_and_metadata_factory,
118+
)
119+
return existing_session
124120

125121
async def _establish_new_connection(
126122
self,
@@ -191,31 +187,6 @@ async def _establish_new_connection(
191187
f"Failed to create ws after retrying {max_retry} number of times",
192188
) from last_error
193189

194-
async def _create_new_session(
195-
self,
196-
) -> Session:
197-
logger.info("Creating new session")
198-
new_ws, hs_request, hs_response = await self._establish_new_connection()
199-
if not hs_response.status.ok:
200-
message = hs_response.status.reason
201-
raise RiverException(
202-
ERROR_SESSION,
203-
f"Server did not return OK status on handshake response: {message}",
204-
)
205-
new_session = Session(
206-
transport_id=self._transport_id,
207-
to_id=self._server_id,
208-
session_id=hs_request.sessionId,
209-
websocket=new_ws,
210-
transport_options=self._transport_options,
211-
close_session_callback=self._delete_session,
212-
retry_connection_callback=self._retry_connection,
213-
)
214-
215-
self._session = new_session
216-
await new_session.start_serve_responses()
217-
return new_session
218-
219190
async def _retry_connection(self) -> Session:
220191
if not self._transport_options.transparent_reconnect:
221192
await self._close_session()

src/replit_river/v2/session.py

Lines changed: 177 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
)
1616

1717
import nanoid # type: ignore
18+
import websockets.asyncio.client
1819
from aiochannel import Channel
1920
from aiochannel.errors import ChannelClosed
2021
from opentelemetry.trace import Span, use_span
2122
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
23+
from pydantic import ValidationError
2224
from websockets.asyncio.client import ClientConnection
2325
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
2426
from websockets.frames import CloseCode
@@ -33,6 +35,7 @@
3335
from replit_river.error_schema import (
3436
ERROR_CODE_CANCEL,
3537
ERROR_CODE_STREAM_CLOSED,
38+
ERROR_HANDSHAKE,
3639
RiverError,
3740
RiverException,
3841
RiverServiceException,
@@ -41,11 +44,17 @@
4144
)
4245
from replit_river.messages import (
4346
FailedSendingMessageException,
47+
WebsocketClosedException,
4448
parse_transport_msg,
49+
send_transport_message,
4550
)
51+
from replit_river.rate_limiter import LeakyBucketRateLimit
4652
from replit_river.rpc import (
4753
ACK_BIT,
4854
STREAM_OPEN_BIT,
55+
ControlMessageHandshakeRequest,
56+
ControlMessageHandshakeResponse,
57+
ExpectedSessionState,
4958
TransportMessage,
5059
TransportMessageTracingSetter,
5160
)
@@ -55,7 +64,15 @@
5564
OutOfOrderMessageException,
5665
)
5766
from replit_river.task_manager import BackgroundTaskManager
58-
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions
67+
from replit_river.transport_options import (
68+
MAX_MESSAGE_BUFFER_SIZE,
69+
TransportOptions,
70+
UriAndMetadata,
71+
)
72+
from replit_river.v2.client_transport import (
73+
PROTOCOL_VERSION,
74+
HandshakeBudgetExhaustedException,
75+
)
5976

6077
STREAM_CANCEL_BIT_TYPE = Literal[0b00100]
6178
STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100
@@ -107,7 +124,6 @@ def __init__(
107124
transport_id: str,
108125
to_id: str,
109126
session_id: str,
110-
websocket: ClientConnection,
111127
transport_options: TransportOptions,
112128
close_session_callback: CloseSessionCallback,
113129
retry_connection_callback: RetryConnectionCallback | None = None,
@@ -123,8 +139,7 @@ def __init__(
123139
self._close_session_after_time_secs: float | None = None
124140

125141
# ws state
126-
self._ws_connected = True
127-
self._ws_unwrapped = websocket
142+
self._ws_connected = False
128143
self._heartbeat_misses = 0
129144
self._retry_connection_callback = retry_connection_callback
130145

@@ -187,6 +202,164 @@ def get_next_pending() -> TransportMessage | None:
187202
)
188203
)
189204

205+
async def ensure_connected(
206+
self,
207+
client_id: str,
208+
rate_limiter: LeakyBucketRateLimit,
209+
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
210+
) -> None:
211+
"""
212+
Either return immediately or establish a websocket connection and return
213+
once we can accept messages
214+
"""
215+
if self._ws_unwrapped and self._ws_connected:
216+
return
217+
max_retry = self._transport_options.connection_retry_options.max_retry
218+
logger.info("Attempting to establish new ws connection")
219+
220+
last_error: Exception | None = None
221+
for i in range(max_retry):
222+
if i > 0:
223+
logger.info(f"Retrying build handshake number {i} times")
224+
if not rate_limiter.has_budget(client_id):
225+
logger.debug("No retry budget for %s.", client_id)
226+
raise HandshakeBudgetExhaustedException(
227+
ERROR_HANDSHAKE,
228+
"No retry budget",
229+
client_id=client_id,
230+
) from last_error
231+
232+
rate_limiter.consume_budget(client_id)
233+
234+
try:
235+
uri_and_metadata = await uri_and_metadata_factory()
236+
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
237+
238+
try:
239+
try:
240+
expectedSessionState = ExpectedSessionState(
241+
nextExpectedSeq=0,
242+
nextSentSeq=0,
243+
)
244+
handshake_request = ControlMessageHandshakeRequest[Any](
245+
type="HANDSHAKE_REQ",
246+
protocolVersion=PROTOCOL_VERSION,
247+
sessionId=self.session_id,
248+
metadata=uri_and_metadata["metadata"],
249+
expectedSessionState=expectedSessionState,
250+
)
251+
stream_id = nanoid.generate()
252+
253+
async def websocket_closed_callback() -> None:
254+
logger.error("websocket closed before handshake response")
255+
256+
try:
257+
payload = handshake_request.model_dump()
258+
await send_transport_message(
259+
TransportMessage(
260+
from_=self._transport_id,
261+
to=self._to_id,
262+
streamId=stream_id,
263+
controlFlags=0,
264+
id=nanoid.generate(),
265+
seq=0,
266+
ack=0,
267+
payload=payload,
268+
),
269+
ws=ws,
270+
websocket_closed_callback=websocket_closed_callback,
271+
)
272+
except (
273+
WebsocketClosedException,
274+
FailedSendingMessageException,
275+
) as e: # noqa: E501
276+
raise RiverException(
277+
ERROR_HANDSHAKE,
278+
"Handshake failed, conn closed while sending response", # noqa: E501
279+
) from e
280+
except FailedSendingMessageException as e:
281+
raise RiverException(
282+
ERROR_CODE_STREAM_CLOSED,
283+
"Stream closed before response, closing connection",
284+
) from e
285+
286+
startup_grace_deadline_ms = await self._get_current_time() + 60_000
287+
try:
288+
while True:
289+
if (
290+
await self._get_current_time()
291+
>= startup_grace_deadline_ms
292+
): # noqa: E501
293+
raise RiverException(
294+
ERROR_HANDSHAKE,
295+
"Handshake response timeout, closing connection", # noqa: E501
296+
)
297+
try:
298+
data = await ws.recv()
299+
except ConnectionClosed as e:
300+
logger.debug(
301+
"Connection closed during waiting for handshake response", # noqa: E501
302+
exc_info=True,
303+
)
304+
raise RiverException(
305+
ERROR_HANDSHAKE,
306+
"Handshake failed, conn closed while waiting for response", # noqa: E501
307+
) from e
308+
try:
309+
response_msg = parse_transport_msg(data)
310+
break
311+
except IgnoreMessageException:
312+
logger.debug(
313+
"Ignoring transport message", exc_info=True
314+
) # noqa: E501
315+
continue
316+
except InvalidMessageException as e:
317+
raise RiverException(
318+
ERROR_HANDSHAKE,
319+
"Got invalid transport message, closing connection",
320+
) from e
321+
322+
handshake_response = ControlMessageHandshakeResponse(
323+
**response_msg.payload
324+
) # noqa: E501
325+
logger.debug("river client waiting for handshake response")
326+
except ValidationError as e:
327+
raise RiverException(
328+
ERROR_HANDSHAKE, "Failed to parse handshake response"
329+
) from e
330+
except asyncio.TimeoutError as e:
331+
raise RiverException(
332+
ERROR_HANDSHAKE,
333+
"Handshake response timeout, closing connection", # noqa: E501
334+
) from e
335+
336+
logger.debug(
337+
"river client get handshake response : %r", handshake_response
338+
) # noqa: E501
339+
if not handshake_response.status.ok:
340+
raise RiverException(
341+
ERROR_HANDSHAKE,
342+
f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501
343+
+ f"{handshake_response.status.reason}",
344+
)
345+
346+
rate_limiter.start_restoring_budget(client_id)
347+
except RiverException as e:
348+
await ws.close()
349+
raise e
350+
except Exception as e:
351+
last_error = e
352+
backoff_time = rate_limiter.get_backoff_ms(client_id)
353+
logger.exception(
354+
f"Error connecting, retrying with {backoff_time}ms backoff"
355+
)
356+
await asyncio.sleep(backoff_time / 1000)
357+
358+
raise RiverException(
359+
ERROR_HANDSHAKE,
360+
f"Failed to create ws after retrying {max_retry} number of times",
361+
) from last_error
362+
190363
def _setup_heartbeats_task(
191364
self,
192365
do_close_websocket: Callable[[], Awaitable[None]],

0 commit comments

Comments
 (0)