Skip to content

Commit 34ef3f5

Browse files
More lifecycle management
1 parent 3cd44a2 commit 34ef3f5

File tree

2 files changed

+23
-258
lines changed

2 files changed

+23
-258
lines changed
Lines changed: 12 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,13 @@
1-
import asyncio
21
import logging
32
from collections.abc import Awaitable, Callable
4-
from typing import Generic, assert_never
3+
from typing import Generic
54

65
import nanoid
7-
import websockets
8-
import websockets.asyncio.client
9-
from pydantic import ValidationError
10-
from websockets.asyncio.client import ClientConnection
11-
from websockets.exceptions import ConnectionClosed
126

137
from replit_river.error_schema import (
14-
ERROR_CODE_STREAM_CLOSED,
15-
ERROR_HANDSHAKE,
168
RiverException,
179
)
18-
from replit_river.messages import (
19-
FailedSendingMessageException,
20-
WebsocketClosedException,
21-
parse_transport_msg,
22-
send_transport_message,
23-
)
2410
from replit_river.rate_limiter import LeakyBucketRateLimit
25-
from replit_river.rpc import (
26-
SESSION_MISMATCH_CODE,
27-
ControlMessageHandshakeRequest,
28-
ControlMessageHandshakeResponse,
29-
ExpectedSessionState,
30-
TransportMessage,
31-
)
32-
from replit_river.seq_manager import (
33-
IgnoreMessageException,
34-
InvalidMessageException,
35-
)
3611
from replit_river.transport_options import (
3712
HandshakeMetadataType,
3813
TransportOptions,
@@ -78,33 +53,25 @@ def __init__(
7853
transport_options.connection_retry_options
7954
)
8055

81-
async def _close_session(self) -> None:
82-
logger.info(f"start closing session {self._transport_id}")
83-
if not self._session:
84-
return
85-
await self._session.close()
86-
logger.info(f"Transport closed {self._transport_id}")
87-
88-
def generate_nanoid(self) -> str:
89-
return str(nanoid.generate())
90-
9156
async def close(self) -> None:
9257
self._rate_limiter.close()
93-
await self._close_session()
58+
if self._session:
59+
logger.info(f"start closing session {self._transport_id}")
60+
await self._session.close()
61+
logger.info(f"Transport closed {self._transport_id}")
9462

9563
async def get_or_create_session(self) -> Session:
9664
"""
97-
If we have an active session, return it.
98-
If we have a "closed" session, mint a whole new session.
99-
If we have a disconnected session, attempt to start a new WS and use it.
65+
Create a session if it does not exist,
66+
call ensure_connected on whatever session is active.
10067
"""
10168
existing_session = self._session
102-
if not existing_session:
69+
if not existing_session or not existing_session.is_session_open():
10370
logger.info("Creating new session")
10471
new_session = Session(
10572
transport_id=self._transport_id,
10673
to_id=self._server_id,
107-
session_id=self.generate_nanoid(),
74+
session_id=nanoid.generate(),
10875
transport_options=self._transport_options,
10976
close_session_callback=self._delete_session,
11077
retry_connection_callback=self._retry_connection,
@@ -121,214 +88,12 @@ async def get_or_create_session(self) -> Session:
12188
)
12289
return existing_session
12390

124-
async def _establish_new_connection(
125-
self,
126-
old_session: Session | None = None,
127-
) -> tuple[
128-
ClientConnection,
129-
ControlMessageHandshakeRequest[HandshakeMetadataType],
130-
ControlMessageHandshakeResponse,
131-
]:
132-
"""Build a new websocket connection with retry logic."""
133-
rate_limit = self._rate_limiter
134-
max_retry = self._transport_options.connection_retry_options.max_retry
135-
client_id = self._client_id
136-
logger.info("Attempting to establish new ws connection")
137-
138-
last_error: Exception | None = None
139-
for i in range(max_retry):
140-
if i > 0:
141-
logger.info(f"Retrying build handshake number {i} times")
142-
if not rate_limit.has_budget(client_id):
143-
logger.debug("No retry budget for %s.", client_id)
144-
raise HandshakeBudgetExhaustedException(
145-
ERROR_HANDSHAKE,
146-
"No retry budget",
147-
client_id=client_id,
148-
) from last_error
149-
150-
rate_limit.consume_budget(client_id)
151-
152-
# if the session is closed, we shouldn't use it
153-
if old_session and not old_session.is_session_open():
154-
old_session = None
155-
156-
try:
157-
uri_and_metadata = await self._uri_and_metadata_factory()
158-
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
159-
session_id: str
160-
if old_session:
161-
session_id = old_session.session_id
162-
else:
163-
session_id = self.generate_nanoid()
164-
165-
try:
166-
(
167-
handshake_request,
168-
handshake_response,
169-
) = await self._establish_handshake(
170-
session_id,
171-
uri_and_metadata["metadata"],
172-
ws,
173-
old_session,
174-
)
175-
rate_limit.start_restoring_budget(client_id)
176-
return ws, handshake_request, handshake_response
177-
except RiverException as e:
178-
await ws.close()
179-
raise e
180-
except Exception as e:
181-
last_error = e
182-
backoff_time = rate_limit.get_backoff_ms(client_id)
183-
logger.exception(
184-
f"Error connecting, retrying with {backoff_time}ms backoff"
185-
)
186-
await asyncio.sleep(backoff_time / 1000)
187-
188-
raise RiverException(
189-
ERROR_HANDSHAKE,
190-
f"Failed to create ws after retrying {max_retry} number of times",
191-
) from last_error
192-
19391
async def _retry_connection(self) -> Session:
194-
if not self._transport_options.transparent_reconnect:
195-
await self._close_session()
92+
if not self._transport_options.transparent_reconnect and self._session:
93+
logger.info("transparent_reconnect not set, closing {self._transport_id}")
94+
await self._session.close()
19695
return await self.get_or_create_session()
19796

198-
async def _send_handshake_request(
199-
self,
200-
session_id: str,
201-
handshake_metadata: HandshakeMetadataType | None,
202-
websocket: ClientConnection,
203-
expected_session_state: ExpectedSessionState,
204-
) -> ControlMessageHandshakeRequest[HandshakeMetadataType]:
205-
handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType](
206-
type="HANDSHAKE_REQ",
207-
protocolVersion=PROTOCOL_VERSION,
208-
sessionId=session_id,
209-
metadata=handshake_metadata,
210-
expectedSessionState=expected_session_state,
211-
)
212-
stream_id = self.generate_nanoid()
213-
214-
async def websocket_closed_callback() -> None:
215-
logger.error("websocket closed before handshake response")
216-
217-
try:
218-
payload = handshake_request.model_dump()
219-
await send_transport_message(
220-
TransportMessage(
221-
from_=self._transport_id,
222-
to=self._server_id,
223-
streamId=stream_id,
224-
controlFlags=0,
225-
id=self.generate_nanoid(),
226-
seq=0,
227-
ack=0,
228-
payload=payload,
229-
),
230-
ws=websocket,
231-
websocket_closed_callback=websocket_closed_callback,
232-
)
233-
return handshake_request
234-
except (WebsocketClosedException, FailedSendingMessageException) as e:
235-
raise RiverException(
236-
ERROR_HANDSHAKE, "Handshake failed, conn closed while sending response"
237-
) from e
238-
239-
async def _get_handshake_response_msg(
240-
self, websocket: ClientConnection
241-
) -> TransportMessage:
242-
while True:
243-
try:
244-
data = await websocket.recv()
245-
except ConnectionClosed as e:
246-
logger.debug(
247-
"Connection closed during waiting for handshake response",
248-
exc_info=True,
249-
)
250-
raise RiverException(
251-
ERROR_HANDSHAKE,
252-
"Handshake failed, conn closed while waiting for response",
253-
) from e
254-
try:
255-
return parse_transport_msg(data)
256-
except IgnoreMessageException:
257-
logger.debug("Ignoring transport message", exc_info=True)
258-
continue
259-
except InvalidMessageException as e:
260-
raise RiverException(
261-
ERROR_HANDSHAKE,
262-
"Got invalid transport message, closing connection",
263-
) from e
264-
265-
async def _establish_handshake(
266-
self,
267-
session_id: str,
268-
handshake_metadata: HandshakeMetadataType,
269-
websocket: ClientConnection,
270-
old_session: Session | None,
271-
) -> tuple[
272-
ControlMessageHandshakeRequest[HandshakeMetadataType],
273-
ControlMessageHandshakeResponse,
274-
]:
275-
try:
276-
expectedSessionState: ExpectedSessionState
277-
match old_session:
278-
case None:
279-
expectedSessionState = ExpectedSessionState(
280-
nextExpectedSeq=0,
281-
nextSentSeq=0,
282-
)
283-
case Session():
284-
expectedSessionState = ExpectedSessionState(
285-
nextExpectedSeq=old_session.ack,
286-
nextSentSeq=old_session.seq,
287-
)
288-
case other:
289-
assert_never(other)
290-
handshake_request = await self._send_handshake_request(
291-
session_id=session_id,
292-
handshake_metadata=handshake_metadata,
293-
websocket=websocket,
294-
expected_session_state=expectedSessionState,
295-
)
296-
except FailedSendingMessageException as e:
297-
raise RiverException(
298-
ERROR_CODE_STREAM_CLOSED,
299-
"Stream closed before response, closing connection",
300-
) from e
301-
302-
startup_grace_sec = 60
303-
try:
304-
response_msg = await asyncio.wait_for(
305-
self._get_handshake_response_msg(websocket), startup_grace_sec
306-
)
307-
handshake_response = ControlMessageHandshakeResponse(**response_msg.payload)
308-
logger.debug("river client waiting for handshake response")
309-
except ValidationError as e:
310-
raise RiverException(
311-
ERROR_HANDSHAKE, "Failed to parse handshake response"
312-
) from e
313-
except asyncio.TimeoutError as e:
314-
raise RiverException(
315-
ERROR_HANDSHAKE, "Handshake response timeout, closing connection"
316-
) from e
317-
318-
logger.debug("river client get handshake response : %r", handshake_response)
319-
if not handshake_response.status.ok:
320-
if old_session and handshake_response.status.code == SESSION_MISMATCH_CODE:
321-
# If the session status is mismatched, we should close the old session
322-
# and let the retry logic create a new session.
323-
await old_session.close()
324-
325-
raise RiverException(
326-
ERROR_HANDSHAKE,
327-
f"Handshake failed with code ${handshake_response.status.code}: "
328-
+ f"{handshake_response.status.reason}",
329-
)
330-
return handshake_request, handshake_response
331-
33297
async def _delete_session(self, session: Session) -> None:
33398
if self._session and session._to_id == self._session._to_id:
33499
self._session = None

src/replit_river/v2/session.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def __init__(
139139

140140
# ws state
141141
self._ws_connected = False
142+
self._ws_unwrapped = None
142143
self._heartbeat_misses = 0
143144
self._retry_connection_callback = retry_connection_callback
144145

@@ -487,9 +488,6 @@ async def close(self) -> None:
487488
# invocation, so let's await this close to avoid dropping the socket.
488489
await self._ws_unwrapped.close()
489490

490-
# Clear the session in transports
491-
await self._close_session_callback(self)
492-
493491
# TODO: unexpected_close should close stream differently here to
494492
# throw exception correctly.
495493
for stream in self._streams.values():
@@ -498,6 +496,10 @@ async def close(self) -> None:
498496

499497
self._state = SessionState.CLOSED
500498

499+
# Clear the session in transports
500+
# This will get us GC'd, so this should be the last thing.
501+
await self._close_session_callback(self)
502+
501503
async def start_serve_responses(self) -> None:
502504
self._task_manager.create_task(self._serve())
503505

@@ -528,20 +530,16 @@ async def _serve(self) -> None:
528530
)
529531

530532
async def _handle_messages_from_ws(self) -> None:
531-
while self._ws_unwrapped is None:
533+
while self._ws_unwrapped is None or not self._ws_connected:
532534
await asyncio.sleep(1)
533535
logger.debug(
534536
"%s start handling messages from ws %s",
535537
"client",
536538
self._ws_unwrapped.id,
537539
)
538540
try:
539-
ws = self._ws_unwrapped
540-
while True:
541-
if not self._ws_unwrapped:
542-
# We should not process messages if the websocket is closed.
543-
break
544-
541+
# We should not process messages if the websocket is closed.
542+
while ws := self._ws_unwrapped:
545543
# decode=False: Avoiding an unnecessary round-trip through str
546544
# Ideally this should be type-ascribed to : bytes, but there is no
547545
# @overrides in `websockets` to hint this.
@@ -573,14 +571,16 @@ async def _handle_messages_from_ws(self) -> None:
573571
# Set our next expected ack number
574572
self.ack = msg.seq + 1
575573

576-
# Discard old messages from the buffer
574+
# Discard old server-ack'd messages from the ack buffer
577575
while self._ack_buffer and self._ack_buffer[0].seq < msg.ack:
578576
self._ack_buffer.popleft()
579577

580578
self._reset_session_close_countdown()
581579

580+
# Shortcut to avoid processing ack packets
582581
if msg.controlFlags & ACK_BIT != 0:
583582
continue
583+
584584
stream = self._streams.get(msg.streamId, None)
585585
if msg.controlFlags & STREAM_OPEN_BIT != 0:
586586
raise InvalidMessageException(

0 commit comments

Comments
 (0)