|
28 | 28 | from pydantic import ValidationError |
29 | 29 | from websockets.asyncio.client import ClientConnection |
30 | 30 | from websockets.exceptions import ConnectionClosed, ConnectionClosedOK |
| 31 | +from websockets.protocol import CLOSED |
31 | 32 |
|
32 | 33 | from replit_river.common_session import ( |
33 | 34 | ConnectingStates, |
@@ -199,7 +200,6 @@ def __init__( |
199 | 200 | # Terminating |
200 | 201 | self._terminating_task = None |
201 | 202 |
|
202 | | - self._start_heartbeat() |
203 | 203 | self._start_serve_responses() |
204 | 204 | self._start_close_session_checker() |
205 | 205 | self._start_buffered_message_sender() |
@@ -497,65 +497,24 @@ async def block_until_message_available() -> None: |
497 | 497 | ) |
498 | 498 |
|
499 | 499 | def _start_close_session_checker(self) -> None: |
500 | | - def do_close() -> None: |
501 | | - # Avoid closing twice |
502 | | - if self._terminating_task is None: |
503 | | - # We can't just call self.close() directly because |
504 | | - # we're inside a thread that will eventually be awaited |
505 | | - # during the cleanup procedure. |
506 | | - self._terminating_task = asyncio.create_task(self.close()) |
| 500 | + def transition_connecting() -> None: |
| 501 | + if self._state in TerminalStates: |
| 502 | + return |
| 503 | + self._state = SessionState.CONNECTING |
| 504 | + self._wait_for_connected.clear() |
507 | 505 |
|
508 | 506 | self._task_manager.create_task( |
509 | 507 | _check_to_close_session( |
510 | 508 | self._transport_id, |
511 | 509 | self._transport_options.close_session_check_interval_ms, |
512 | 510 | lambda: self._state, |
513 | | - self._get_current_time, |
514 | | - lambda: self._close_session_after_time_secs, |
515 | | - do_close=do_close, |
516 | | - ) |
517 | | - ) |
518 | | - |
519 | | - def _start_heartbeat(self) -> None: |
520 | | - async def close_websocket() -> None: |
521 | | - logger.debug( |
522 | | - "close_websocket called, _state=%r, _ws=%r", |
523 | | - self._state, |
524 | | - self._ws, |
525 | | - ) |
526 | | - if self._ws: |
527 | | - self._task_manager.create_task(self._ws.close()) |
528 | | - self._ws = None |
529 | | - |
530 | | - if self._retry_connection_callback: |
531 | | - self._task_manager.create_task(self._retry_connection_callback()) |
532 | | - else: |
533 | | - self._state = SessionState.CLOSING |
534 | | - |
535 | | - await self._begin_close_session_countdown() |
536 | | - |
537 | | - def increment_and_get_heartbeat_misses() -> int: |
538 | | - self._heartbeat_misses += 1 |
539 | | - return self._heartbeat_misses |
540 | | - |
541 | | - async def block_until_connected() -> None: |
542 | | - await self._wait_for_connected.wait() |
543 | | - |
544 | | - self._task_manager.create_task( |
545 | | - _setup_heartbeat( |
546 | | - block_until_connected, |
547 | | - self.session_id, |
548 | | - self._transport_options.heartbeat_ms, |
549 | | - self._transport_options.heartbeats_until_dead, |
550 | | - lambda: self._state, |
551 | | - lambda: self._close_session_after_time_secs, |
552 | | - close_websocket=close_websocket, |
553 | | - increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, |
| 511 | + lambda: self._ws, |
| 512 | + transition_connecting=transition_connecting, |
554 | 513 | ) |
555 | 514 | ) |
556 | 515 |
|
557 | 516 | def _start_serve_responses(self) -> None: |
558 | | - async def transition_connecting() -> None: |
| 517 | + def transition_connecting() -> None: |
559 | 518 | if self._state in TerminalStates: |
560 | 519 | return |
561 | 520 | self._state = SessionState.CONNECTING |
@@ -973,31 +932,16 @@ async def _check_to_close_session( |
973 | 932 | transport_id: str, |
974 | 933 | close_session_check_interval_ms: float, |
975 | 934 | get_state: Callable[[], SessionState], |
976 | | - get_current_time: Callable[[], Awaitable[float]], |
977 | | - get_close_session_after_time_secs: Callable[[], float | None], |
978 | | - do_close: Callable[[], None], |
| 935 | + get_ws: Callable[[], ClientConnection | None], |
| 936 | + transition_connecting: Callable[[], None], |
979 | 937 | ) -> None: |
980 | | - our_task = asyncio.current_task() |
981 | | - while our_task and not our_task.cancelling() and not our_task.cancelled(): |
| 938 | + while get_state() not in TerminalStates: |
982 | 939 | logger.debug("_check_to_close_session: Checking") |
983 | 940 | await asyncio.sleep(close_session_check_interval_ms / 1000) |
984 | | - if get_state() in TerminalStates: |
985 | | - # already closing |
986 | | - break |
987 | | - # calculate the value now before comparing it so that there are no |
988 | | - # await points between the check and the comparison to avoid a TOCTOU |
989 | | - # race. |
990 | | - current_time = await get_current_time() |
991 | | - close_session_after_time_secs = get_close_session_after_time_secs() |
992 | | - if not close_session_after_time_secs: |
993 | | - logger.debug( |
994 | | - f"_check_to_close_session: Not reached: {close_session_after_time_secs}" |
995 | | - ) |
996 | | - continue |
997 | | - if current_time > close_session_after_time_secs: |
| 941 | + |
| 942 | + if not (ws := get_ws()) or ws.protocol.state is CLOSED: |
998 | 943 | logger.info("Grace period ended for %s, closing session", transport_id) |
999 | | - do_close() |
1000 | | - our_task.cancel() |
| 944 | + transition_connecting() |
1001 | 945 |
|
1002 | 946 |
|
1003 | 947 | async def _do_ensure_connected[HandshakeMetadata]( |
@@ -1160,53 +1104,12 @@ async def websocket_closed_callback() -> None: |
1160 | 1104 | return None |
1161 | 1105 |
|
1162 | 1106 |
|
1163 | | -async def _setup_heartbeat( |
1164 | | - block_until_connected: Callable[[], Awaitable[None]], |
1165 | | - session_id: str, |
1166 | | - heartbeat_ms: float, |
1167 | | - heartbeats_until_dead: int, |
1168 | | - get_state: Callable[[], SessionState], |
1169 | | - get_closing_grace_period: Callable[[], float | None], |
1170 | | - close_websocket: Callable[[], Awaitable[None]], |
1171 | | - increment_and_get_heartbeat_misses: Callable[[], int], |
1172 | | -) -> None: |
1173 | | - while True: |
1174 | | - while (state := get_state()) in ConnectingStates: |
1175 | | - logger.debug( |
1176 | | - "Heartbeat: block_until_connected: %r", |
1177 | | - state, |
1178 | | - ) |
1179 | | - await block_until_connected() |
1180 | | - |
1181 | | - if state in TerminalStates: |
1182 | | - logger.debug( |
1183 | | - "Session is closed, no need to send heartbeat, state : " |
1184 | | - "%r close_session_after_this: %r", |
1185 | | - state, |
1186 | | - get_closing_grace_period(), |
1187 | | - ) |
1188 | | - # session is closing / closed, no need to send heartbeat anymore |
1189 | | - break |
1190 | | - |
1191 | | - await asyncio.sleep(heartbeat_ms / 1000) |
1192 | | - |
1193 | | - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: |
1194 | | - if get_closing_grace_period() is not None: |
1195 | | - # already in grace period, no need to set again |
1196 | | - continue |
1197 | | - logger.info( |
1198 | | - "%r closing websocket because of heartbeat misses", |
1199 | | - session_id, |
1200 | | - ) |
1201 | | - await close_websocket() |
1202 | | - |
1203 | | - |
1204 | 1107 | async def _serve( |
1205 | 1108 | block_until_connected: Callable[[], Awaitable[None]], |
1206 | 1109 | transport_id: str, |
1207 | 1110 | get_state: Callable[[], SessionState], |
1208 | 1111 | get_ws: Callable[[], ClientConnection | None], |
1209 | | - transition_connecting: Callable[[], Awaitable[None]], |
| 1112 | + transition_connecting: Callable[[], None], |
1210 | 1113 | transition_no_connection: Callable[[], Awaitable[None]], |
1211 | 1114 | reset_session_close_countdown: Callable[[], None], |
1212 | 1115 | close_session: Callable[[], Awaitable[None]], |
@@ -1262,7 +1165,7 @@ async def _serve( |
1262 | 1165 | try: |
1263 | 1166 | message = await ws.recv(decode=False) |
1264 | 1167 | except ConnectionClosed: |
1265 | | - await transition_connecting() |
| 1168 | + transition_connecting() |
1266 | 1169 | continue |
1267 | 1170 | try: |
1268 | 1171 | msg = parse_transport_msg(message) |
|
0 commit comments