Skip to content

Commit 805aedc

Browse files
Exploratory, just transition_no_connection from close checker
1 parent 1255cbd commit 805aedc

File tree

1 file changed

+36
-30
lines changed

1 file changed

+36
-30
lines changed

src/replit_river/v2/session.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,25 @@ def __init__(
215215
# Terminating
216216
self._terminating_task = None
217217

218-
self._start_serve_responses()
219-
self._start_close_session_checker()
218+
async def transition_no_connection() -> None:
219+
if self._state in TerminalStates:
220+
return
221+
self._state = SessionState.NO_CONNECTION
222+
if self._ws:
223+
self._task_manager.create_task(self._ws.close())
224+
self._ws = None
225+
226+
if self._retry_connection_callback:
227+
self._task_manager.create_task(self._retry_connection_callback())
228+
229+
await self._begin_close_session_countdown()
230+
231+
self._start_serve_responses(
232+
transition_no_connection=transition_no_connection,
233+
)
234+
self._start_close_session_checker(
235+
transition_no_connection=transition_no_connection,
236+
)
220237
self._start_buffered_message_sender()
221238

222239
async def ensure_connected(self) -> None:
@@ -455,7 +472,9 @@ async def close(self) -> None:
455472
# This will get us GC'd, so this should be the last thing.
456473
await self._close_session_callback(self)
457474

458-
def _start_buffered_message_sender(self) -> None:
475+
def _start_buffered_message_sender(
476+
self,
477+
) -> None:
459478
def commit(msg: TransportMessage) -> None:
460479
pending = self._send_buffer.popleft()
461480
if msg.seq != pending.seq:
@@ -504,42 +523,29 @@ async def block_until_message_available() -> None:
504523
)
505524
)
506525

507-
def _start_close_session_checker(self) -> None:
508-
def transition_connecting() -> None:
509-
if self._state in TerminalStates:
510-
return
511-
self._state = SessionState.CONNECTING
512-
self._wait_for_connected.clear()
513-
526+
def _start_close_session_checker(
527+
self,
528+
transition_no_connection: Callable[[], Awaitable[None]],
529+
) -> None:
514530
self._task_manager.create_task(
515531
_check_to_close_session(
516-
self._transport_options.close_session_check_interval_ms,
517-
lambda: self._state,
518-
lambda: self._ws,
519-
transition_connecting=transition_connecting,
532+
close_session_check_interval_ms=self._transport_options.close_session_check_interval_ms,
533+
get_state=lambda: self._state,
534+
get_ws=lambda: self._ws,
535+
transition_no_connection=transition_no_connection,
520536
)
521537
)
522538

523-
def _start_serve_responses(self) -> None:
539+
def _start_serve_responses(
540+
self,
541+
transition_no_connection: Callable[[], Awaitable[None]],
542+
) -> None:
524543
def transition_connecting() -> None:
525544
if self._state in TerminalStates:
526545
return
527546
self._state = SessionState.CONNECTING
528547
self._wait_for_connected.clear()
529548

530-
async def transition_no_connection() -> None:
531-
if self._state in TerminalStates:
532-
return
533-
self._state = SessionState.NO_CONNECTION
534-
if self._ws:
535-
self._task_manager.create_task(self._ws.close())
536-
self._ws = None
537-
538-
if self._retry_connection_callback:
539-
self._task_manager.create_task(self._retry_connection_callback())
540-
541-
await self._begin_close_session_countdown()
542-
543549
def assert_incoming_seq_bookkeeping(
544550
msg_from: str,
545551
msg_seq: int,
@@ -939,15 +945,15 @@ async def _check_to_close_session(
939945
close_session_check_interval_ms: float,
940946
get_state: Callable[[], SessionState],
941947
get_ws: Callable[[], ClientConnection | None],
942-
transition_connecting: Callable[[], None],
948+
transition_no_connection: Callable[[], Awaitable[None]],
943949
) -> None:
944950
while get_state() not in TerminalStates:
945951
logger.debug("_check_to_close_session: Checking")
946952
await asyncio.sleep(close_session_check_interval_ms / 1000)
947953

948954
if (ws := get_ws()) and ws.protocol.state is CLOSED:
949955
logger.info("Websocket is closed, transitioning to connecting")
950-
transition_connecting()
956+
await transition_no_connection()
951957

952958

953959
async def _do_ensure_connected[HandshakeMetadata](

0 commit comments

Comments
 (0)