Skip to content

Commit 8208d2b

Browse files
PR feedback
1 parent b4b81b4 commit 8208d2b

File tree

1 file changed

+52
-35
lines changed

1 file changed

+52
-35
lines changed

src/replit_river/v2/session.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,22 @@ def get_next_sent_seq() -> int:
248248
return self.seq
249249

250250
def close_session(reason: Exception | None) -> None:
251+
# If we're already closing, just let whoever's currently doing it handle it.
252+
if self._state in TerminalStates:
253+
return
254+
251255
# Avoid closing twice
252256
if self._terminating_task is None:
257+
current_state = self._state
258+
self._state = SessionState.CLOSING
259+
253260
# We can't just call self.close() directly because
254261
# we're inside a thread that will eventually be awaited
255262
# during the cleanup procedure.
256-
self._terminating_task = asyncio.create_task(self.close(reason))
263+
264+
self._terminating_task = asyncio.create_task(
265+
self.close(reason, current_state=current_state),
266+
)
257267

258268
def transition_connecting() -> None:
259269
if self._state in TerminalStates:
@@ -301,6 +311,7 @@ def unbind_connecting_task() -> None:
301311
get_next_sent_seq=get_next_sent_seq,
302312
get_current_ack=lambda: self.ack,
303313
get_current_time=self._get_current_time,
314+
get_state=lambda: self._state,
304315
transition_connecting=transition_connecting,
305316
close_ws_in_background=close_ws_in_background,
306317
transition_connected=transition_connected,
@@ -385,12 +396,12 @@ async def _enqueue_message(
385396
# Wake up buffered_message_sender
386397
self._process_messages.set()
387398

388-
async def close(self, reason: Exception | None = None) -> None:
399+
async def close(self, reason: Exception | None = None, current_state: SessionState | None = None ) -> None:
389400
"""Close the session and all associated streams."""
390401
logger.info(
391402
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}"
392403
)
393-
if self._state in TerminalStates:
404+
if (current_state or self._state) in TerminalStates:
394405
# already closing
395406
return
396407
self._state = SessionState.CLOSING
@@ -987,6 +998,7 @@ async def _do_ensure_connected[HandshakeMetadata](
987998
get_current_time: Callable[[], Awaitable[float]],
988999
get_next_sent_seq: Callable[[], int],
9891000
get_current_ack: Callable[[], int],
1001+
get_state: Callable[[], SessionState],
9901002
transition_connecting: Callable[[], None],
9911003
close_ws_in_background: Callable[[ClientConnection], None],
9921004
transition_connected: Callable[[ClientConnection], None],
@@ -998,6 +1010,10 @@ async def _do_ensure_connected[HandshakeMetadata](
9981010
last_error: Exception | None = None
9991011
attempt_count = 0
10001012
while rate_limiter.has_budget(client_id):
1013+
if (state := get_state()) in TerminalStates or state in ActiveStates:
1014+
logger.info(f"_do_ensure_connected stopping due to state={state}")
1015+
break
1016+
10011017
if attempt_count > 0:
10021018
logger.info(f"Retrying build handshake number {attempt_count} times")
10031019
attempt_count += 1
@@ -1051,40 +1067,40 @@ async def websocket_closed_callback() -> None:
10511067
handshake_deadline_ms = (
10521068
await get_current_time() + transport_options.handshake_timeout_ms
10531069
)
1054-
while True:
1055-
if await get_current_time() >= handshake_deadline_ms:
1056-
raise RiverException(
1057-
ERROR_HANDSHAKE,
1058-
"Handshake response timeout, closing connection",
1059-
)
1060-
try:
1061-
data = await ws.recv(decode=False)
1062-
except ConnectionClosed as e:
1063-
logger.debug(
1064-
"_do_ensure_connected: Connection closed during waiting "
1065-
"for handshake response",
1066-
exc_info=True,
1067-
)
1068-
raise RiverException(
1069-
ERROR_HANDSHAKE,
1070-
"Handshake failed, conn closed while waiting for response",
1071-
) from e
10721070

1073-
try:
1074-
response_msg = parse_transport_msg(data)
1075-
if isinstance(response_msg, str):
1076-
logger.debug(
1077-
"_do_ensure_connected: Ignoring transport message",
1078-
exc_info=True,
1079-
)
1080-
continue
1071+
if await get_current_time() >= handshake_deadline_ms:
1072+
raise RiverException(
1073+
ERROR_HANDSHAKE,
1074+
"Handshake response timeout, closing connection",
1075+
)
10811076

1082-
break
1083-
except InvalidMessageException as e:
1084-
raise RiverException(
1085-
ERROR_HANDSHAKE,
1086-
"Got invalid transport message, closing connection",
1087-
) from e
1077+
try:
1078+
data = await ws.recv(decode=False)
1079+
except ConnectionClosed as e:
1080+
logger.debug(
1081+
"_do_ensure_connected: Connection closed during waiting "
1082+
"for handshake response",
1083+
exc_info=True,
1084+
)
1085+
raise RiverException(
1086+
ERROR_HANDSHAKE,
1087+
"Handshake failed, conn closed while waiting for response",
1088+
) from e
1089+
1090+
try:
1091+
response_msg = parse_transport_msg(data)
1092+
except InvalidMessageException as e:
1093+
raise RiverException(
1094+
ERROR_HANDSHAKE,
1095+
"Got invalid transport message, closing connection",
1096+
) from e
1097+
1098+
if isinstance(response_msg, str):
1099+
raise RiverException(
1100+
ERROR_HANDSHAKE,
1101+
"Handshake failed, received a raw string message while waiting "
1102+
"for a handshake response",
1103+
)
10881104

10891105
try:
10901106
handshake_response = ControlMessageHandshakeResponse(
@@ -1105,6 +1121,7 @@ async def websocket_closed_callback() -> None:
11051121
}",
11061122
)
11071123
if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH:
1124+
# A session state mismatch is unrecoverable. Terminate immediately.
11081125
close_session(err)
11091126

11101127
raise err

0 commit comments

Comments
 (0)