Skip to content

Commit b50f9e6

Browse files
Ripping out all the heartbeat stuff in favor of server-directed signaling
1 parent 261945b commit b50f9e6

File tree

1 file changed

+17
-114
lines changed

1 file changed

+17
-114
lines changed

src/replit_river/v2/session.py

Lines changed: 17 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pydantic import ValidationError
2929
from websockets.asyncio.client import ClientConnection
3030
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
31+
from websockets.protocol import CLOSED
3132

3233
from replit_river.common_session import (
3334
ConnectingStates,
@@ -199,7 +200,6 @@ def __init__(
199200
# Terminating
200201
self._terminating_task = None
201202

202-
self._start_heartbeat()
203203
self._start_serve_responses()
204204
self._start_close_session_checker()
205205
self._start_buffered_message_sender()
@@ -497,65 +497,24 @@ async def block_until_message_available() -> None:
497497
)
498498

499499
def _start_close_session_checker(self) -> None:
500-
def do_close() -> None:
501-
# Avoid closing twice
502-
if self._terminating_task is None:
503-
# We can't just call self.close() directly because
504-
# we're inside a thread that will eventually be awaited
505-
# during the cleanup procedure.
506-
self._terminating_task = asyncio.create_task(self.close())
500+
def transition_connecting() -> None:
501+
if self._state in TerminalStates:
502+
return
503+
self._state = SessionState.CONNECTING
504+
self._wait_for_connected.clear()
507505

508506
self._task_manager.create_task(
509507
_check_to_close_session(
510508
self._transport_id,
511509
self._transport_options.close_session_check_interval_ms,
512510
lambda: self._state,
513-
self._get_current_time,
514-
lambda: self._close_session_after_time_secs,
515-
do_close=do_close,
516-
)
517-
)
518-
519-
def _start_heartbeat(self) -> None:
520-
async def close_websocket() -> None:
521-
logger.debug(
522-
"close_websocket called, _state=%r, _ws=%r",
523-
self._state,
524-
self._ws,
525-
)
526-
if self._ws:
527-
self._task_manager.create_task(self._ws.close())
528-
self._ws = None
529-
530-
if self._retry_connection_callback:
531-
self._task_manager.create_task(self._retry_connection_callback())
532-
else:
533-
self._state = SessionState.CLOSING
534-
535-
await self._begin_close_session_countdown()
536-
537-
def increment_and_get_heartbeat_misses() -> int:
538-
self._heartbeat_misses += 1
539-
return self._heartbeat_misses
540-
541-
async def block_until_connected() -> None:
542-
await self._wait_for_connected.wait()
543-
544-
self._task_manager.create_task(
545-
_setup_heartbeat(
546-
block_until_connected,
547-
self.session_id,
548-
self._transport_options.heartbeat_ms,
549-
self._transport_options.heartbeats_until_dead,
550-
lambda: self._state,
551-
lambda: self._close_session_after_time_secs,
552-
close_websocket=close_websocket,
553-
increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses,
511+
lambda: self._ws,
512+
transition_connecting=transition_connecting,
554513
)
555514
)
556515

557516
def _start_serve_responses(self) -> None:
558-
async def transition_connecting() -> None:
517+
def transition_connecting() -> None:
559518
if self._state in TerminalStates:
560519
return
561520
self._state = SessionState.CONNECTING
@@ -973,31 +932,16 @@ async def _check_to_close_session(
973932
transport_id: str,
974933
close_session_check_interval_ms: float,
975934
get_state: Callable[[], SessionState],
976-
get_current_time: Callable[[], Awaitable[float]],
977-
get_close_session_after_time_secs: Callable[[], float | None],
978-
do_close: Callable[[], None],
935+
get_ws: Callable[[], ClientConnection | None],
936+
transition_connecting: Callable[[], None],
979937
) -> None:
980-
our_task = asyncio.current_task()
981-
while our_task and not our_task.cancelling() and not our_task.cancelled():
938+
while get_state() not in TerminalStates:
982939
logger.debug("_check_to_close_session: Checking")
983940
await asyncio.sleep(close_session_check_interval_ms / 1000)
984-
if get_state() in TerminalStates:
985-
# already closing
986-
break
987-
# calculate the value now before comparing it so that there are no
988-
# await points between the check and the comparison to avoid a TOCTOU
989-
# race.
990-
current_time = await get_current_time()
991-
close_session_after_time_secs = get_close_session_after_time_secs()
992-
if not close_session_after_time_secs:
993-
logger.debug(
994-
f"_check_to_close_session: Not reached: {close_session_after_time_secs}"
995-
)
996-
continue
997-
if current_time > close_session_after_time_secs:
941+
942+
if not (ws := get_ws()) or ws.protocol.state is CLOSED:
998943
logger.info("Grace period ended for %s, closing session", transport_id)
999-
do_close()
1000-
our_task.cancel()
944+
transition_connecting()
1001945

1002946

1003947
async def _do_ensure_connected[HandshakeMetadata](
@@ -1160,53 +1104,12 @@ async def websocket_closed_callback() -> None:
11601104
return None
11611105

11621106

1163-
async def _setup_heartbeat(
1164-
block_until_connected: Callable[[], Awaitable[None]],
1165-
session_id: str,
1166-
heartbeat_ms: float,
1167-
heartbeats_until_dead: int,
1168-
get_state: Callable[[], SessionState],
1169-
get_closing_grace_period: Callable[[], float | None],
1170-
close_websocket: Callable[[], Awaitable[None]],
1171-
increment_and_get_heartbeat_misses: Callable[[], int],
1172-
) -> None:
1173-
while True:
1174-
while (state := get_state()) in ConnectingStates:
1175-
logger.debug(
1176-
"Heartbeat: block_until_connected: %r",
1177-
state,
1178-
)
1179-
await block_until_connected()
1180-
1181-
if state in TerminalStates:
1182-
logger.debug(
1183-
"Session is closed, no need to send heartbeat, state : "
1184-
"%r close_session_after_this: %r",
1185-
state,
1186-
get_closing_grace_period(),
1187-
)
1188-
# session is closing / closed, no need to send heartbeat anymore
1189-
break
1190-
1191-
await asyncio.sleep(heartbeat_ms / 1000)
1192-
1193-
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
1194-
if get_closing_grace_period() is not None:
1195-
# already in grace period, no need to set again
1196-
continue
1197-
logger.info(
1198-
"%r closing websocket because of heartbeat misses",
1199-
session_id,
1200-
)
1201-
await close_websocket()
1202-
1203-
12041107
async def _serve(
12051108
block_until_connected: Callable[[], Awaitable[None]],
12061109
transport_id: str,
12071110
get_state: Callable[[], SessionState],
12081111
get_ws: Callable[[], ClientConnection | None],
1209-
transition_connecting: Callable[[], Awaitable[None]],
1112+
transition_connecting: Callable[[], None],
12101113
transition_no_connection: Callable[[], Awaitable[None]],
12111114
reset_session_close_countdown: Callable[[], None],
12121115
close_session: Callable[[], Awaitable[None]],
@@ -1262,7 +1165,7 @@ async def _serve(
12621165
try:
12631166
message = await ws.recv(decode=False)
12641167
except ConnectionClosed:
1265-
await transition_connecting()
1168+
transition_connecting()
12661169
continue
12671170
try:
12681171
msg = parse_transport_msg(message)

0 commit comments

Comments
 (0)