Skip to content

Commit 4ed6178

Browse files
Exploratory, just transition_no_connection from close checker
1 parent 1255cbd commit 4ed6178

File tree

1 file changed

+43
-30
lines changed

1 file changed

+43
-30
lines changed

src/replit_river/v2/session.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,32 @@ def __init__(
215215
# Terminating
216216
self._terminating_task = None
217217

218-
self._start_serve_responses()
219-
self._start_close_session_checker()
218+
async def transition_connecting() -> None:
219+
if self._state in TerminalStates:
220+
return
221+
self._state = SessionState.CONNECTING
222+
self._wait_for_connected.clear()
223+
await self.ensure_connected()
224+
225+
async def transition_no_connection() -> None:
226+
if self._state in TerminalStates:
227+
return
228+
self._state = SessionState.NO_CONNECTION
229+
if self._ws:
230+
self._task_manager.create_task(self._ws.close())
231+
self._ws = None
232+
233+
if self._retry_connection_callback:
234+
self._task_manager.create_task(self._retry_connection_callback())
235+
236+
await self._begin_close_session_countdown()
237+
238+
self._start_serve_responses(
239+
transition_no_connection=transition_no_connection,
240+
)
241+
self._start_close_session_checker(
242+
transition_no_connection=transition_no_connection,
243+
)
220244
self._start_buffered_message_sender()
221245

222246
async def ensure_connected(self) -> None:
@@ -455,7 +479,9 @@ async def close(self) -> None:
455479
# This will get us GC'd, so this should be the last thing.
456480
await self._close_session_callback(self)
457481

458-
def _start_buffered_message_sender(self) -> None:
482+
def _start_buffered_message_sender(
483+
self,
484+
) -> None:
459485
def commit(msg: TransportMessage) -> None:
460486
pending = self._send_buffer.popleft()
461487
if msg.seq != pending.seq:
@@ -504,42 +530,29 @@ async def block_until_message_available() -> None:
504530
)
505531
)
506532

507-
def _start_close_session_checker(self) -> None:
508-
def transition_connecting() -> None:
509-
if self._state in TerminalStates:
510-
return
511-
self._state = SessionState.CONNECTING
512-
self._wait_for_connected.clear()
513-
533+
def _start_close_session_checker(
534+
self,
535+
transition_no_connection: Callable[[], Awaitable[None]],
536+
) -> None:
514537
self._task_manager.create_task(
515538
_check_to_close_session(
516-
self._transport_options.close_session_check_interval_ms,
517-
lambda: self._state,
518-
lambda: self._ws,
519-
transition_connecting=transition_connecting,
539+
close_session_check_interval_ms=self._transport_options.close_session_check_interval_ms,
540+
get_state=lambda: self._state,
541+
get_ws=lambda: self._ws,
542+
transition_no_connection=transition_no_connection,
520543
)
521544
)
522545

523-
def _start_serve_responses(self) -> None:
546+
def _start_serve_responses(
547+
self,
548+
transition_no_connection: Callable[[], Awaitable[None]],
549+
) -> None:
524550
def transition_connecting() -> None:
525551
if self._state in TerminalStates:
526552
return
527553
self._state = SessionState.CONNECTING
528554
self._wait_for_connected.clear()
529555

530-
async def transition_no_connection() -> None:
531-
if self._state in TerminalStates:
532-
return
533-
self._state = SessionState.NO_CONNECTION
534-
if self._ws:
535-
self._task_manager.create_task(self._ws.close())
536-
self._ws = None
537-
538-
if self._retry_connection_callback:
539-
self._task_manager.create_task(self._retry_connection_callback())
540-
541-
await self._begin_close_session_countdown()
542-
543556
def assert_incoming_seq_bookkeeping(
544557
msg_from: str,
545558
msg_seq: int,
@@ -939,15 +952,15 @@ async def _check_to_close_session(
939952
close_session_check_interval_ms: float,
940953
get_state: Callable[[], SessionState],
941954
get_ws: Callable[[], ClientConnection | None],
942-
transition_connecting: Callable[[], None],
955+
transition_no_connection: Callable[[], Awaitable[None]],
943956
) -> None:
944957
while get_state() not in TerminalStates:
945958
logger.debug("_check_to_close_session: Checking")
946959
await asyncio.sleep(close_session_check_interval_ms / 1000)
947960

948961
if (ws := get_ws()) and ws.protocol.state is CLOSED:
949962
logger.info("Websocket is closed, transitioning to connecting")
950-
transition_connecting()
963+
transition_no_connection()
951964

952965

953966
async def _do_ensure_connected[HandshakeMetadata](

0 commit comments

Comments
 (0)