Skip to content

Commit 253402c

Browse files
Colocating functions from common_session
1 parent 1ba7a9d commit 253402c

File tree

1 file changed

+135
-6
lines changed

1 file changed

+135
-6
lines changed

src/replit_river/v2/session.py

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@
2727
from websockets.legacy.protocol import WebSocketCommonProtocol
2828

2929
from replit_river.common_session import (
30+
SendMessage,
3031
SessionState,
3132
TerminalStates,
32-
buffered_message_sender,
33-
check_to_close_session,
34-
setup_heartbeat,
3533
)
3634
from replit_river.error_schema import (
3735
ERROR_CODE_CANCEL,
@@ -183,7 +181,7 @@ def increment_and_get_heartbeat_misses() -> int:
183181
return self._heartbeat_misses
184182

185183
self._task_manager.create_task(
186-
setup_heartbeat(
184+
_setup_heartbeat(
187185
self.session_id,
188186
self._transport_options.heartbeat_ms,
189187
self._transport_options.heartbeats_until_dead,
@@ -195,7 +193,7 @@ def increment_and_get_heartbeat_misses() -> int:
195193
)
196194
)
197195
self._task_manager.create_task(
198-
check_to_close_session(
196+
_check_to_close_session(
199197
self._transport_id,
200198
self._transport_options.close_session_check_interval_ms,
201199
lambda: self._state,
@@ -227,7 +225,7 @@ def get_ws() -> WebSocketCommonProtocol | ClientConnection | None:
227225
return None
228226

229227
self._task_manager.create_task(
230-
buffered_message_sender(
228+
_buffered_message_sender(
231229
self._connection_condition,
232230
self._message_enqueued,
233231
get_ws=get_ws,
@@ -930,6 +928,137 @@ async def send_close_stream(
930928
)
931929

932930

931+
async def _check_to_close_session(
932+
transport_id: str,
933+
close_session_check_interval_ms: float,
934+
get_state: Callable[[], SessionState],
935+
get_current_time: Callable[[], Awaitable[float]],
936+
get_close_session_after_time_secs: Callable[[], float | None],
937+
do_close: Callable[[], Awaitable[None]],
938+
) -> None:
939+
our_task = asyncio.current_task()
940+
while our_task and not our_task.cancelling() and not our_task.cancelled():
941+
await asyncio.sleep(close_session_check_interval_ms / 1000)
942+
if get_state() in TerminalStates:
943+
# already closing
944+
return
945+
# calculate the value now before comparing it so that there are no
946+
# await points between the check and the comparison to avoid a TOCTOU
947+
# race.
948+
current_time = await get_current_time()
949+
close_session_after_time_secs = get_close_session_after_time_secs()
950+
if not close_session_after_time_secs:
951+
continue
952+
if current_time > close_session_after_time_secs:
953+
logger.info("Grace period ended for %s, closing session", transport_id)
954+
await do_close()
955+
return
956+
957+
958+
async def _buffered_message_sender(
959+
connection_condition: asyncio.Condition,
960+
message_enqueued: asyncio.Semaphore,
961+
get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None],
962+
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
963+
get_next_pending: Callable[[], TransportMessage | None],
964+
commit: Callable[[TransportMessage], None],
965+
) -> None:
966+
while True:
967+
await message_enqueued.acquire()
968+
while (ws := get_ws()) is None:
969+
# Block until we have a handle
970+
logger.debug(
971+
"buffered_message_sender: Waiting until ws is connected (condition=%r)",
972+
connection_condition,
973+
)
974+
async with connection_condition:
975+
await connection_condition.wait()
976+
if msg := get_next_pending():
977+
logger.debug(
978+
"buffered_message_sender: Dequeued %r to send over %r",
979+
msg,
980+
ws,
981+
)
982+
try:
983+
await send_transport_message(msg, ws, websocket_closed_callback)
984+
commit(msg)
985+
except WebsocketClosedException as e:
986+
logger.debug(
987+
"Connection closed while sending message %r, waiting for "
988+
"retry from buffer",
989+
type(e),
990+
exc_info=e,
991+
)
992+
message_enqueued.release()
993+
break
994+
except FailedSendingMessageException:
995+
logger.error(
996+
"Failed sending message, waiting for retry from buffer",
997+
exc_info=True,
998+
)
999+
message_enqueued.release()
1000+
break
1001+
except Exception:
1002+
logger.exception("Error attempting to send buffered messages")
1003+
message_enqueued.release()
1004+
break
1005+
1006+
1007+
async def _setup_heartbeat(
1008+
session_id: str,
1009+
heartbeat_ms: float,
1010+
heartbeats_until_dead: int,
1011+
get_state: Callable[[], SessionState],
1012+
get_closing_grace_period: Callable[[], float | None],
1013+
close_websocket: Callable[[], Awaitable[None]],
1014+
send_message: SendMessage,
1015+
increment_and_get_heartbeat_misses: Callable[[], int],
1016+
) -> None:
1017+
while True:
1018+
await asyncio.sleep(heartbeat_ms / 1000)
1019+
state = get_state()
1020+
if state == SessionState.CONNECTING:
1021+
logger.debug("Websocket is not connected, not sending heartbeat")
1022+
continue
1023+
if state in TerminalStates:
1024+
logger.debug(
1025+
"Session is closed, no need to send heartbeat, state : "
1026+
"%r close_session_after_this: %r",
1027+
{state},
1028+
{get_closing_grace_period()},
1029+
)
1030+
# session is closing / closed, no need to send heartbeat anymore
1031+
return
1032+
try:
1033+
await send_message(
1034+
stream_id="heartbeat",
1035+
# TODO: make this a message class
1036+
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
1037+
payload={
1038+
"type": "ACK",
1039+
"ack": 0,
1040+
},
1041+
control_flags=ACK_BIT,
1042+
procedure_name=None,
1043+
service_name=None,
1044+
span=None,
1045+
)
1046+
1047+
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
1048+
if get_closing_grace_period() is not None:
1049+
# already in grace period, no need to set again
1050+
continue
1051+
logger.info(
1052+
"%r closing websocket because of heartbeat misses",
1053+
session_id,
1054+
)
1055+
await close_websocket()
1056+
continue
1057+
except FailedSendingMessageException:
1058+
# this is expected during websocket closed period
1059+
continue
1060+
1061+
9331062
async def _serve(
9341063
transport_id: str,
9351064
get_state: Callable[[], SessionState],

0 commit comments

Comments
 (0)