Skip to content

Commit 37d2e49

Browse files
Migrating all background tasks to just use block_until_connected
1 parent 434a241 commit 37d2e49

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

src/replit_river/common_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class SessionState(enum.Enum):
4848
CLOSED = 4
4949

5050

51+
ConnectingStates = set([SessionState.PENDING, SessionState.CONNECTING])
5152
TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED])
5253

5354

src/replit_river/v2/session.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
2727

2828
from replit_river.common_session import (
29+
ConnectingStates,
2930
SendMessage,
3031
SessionState,
3132
TerminalStates,
@@ -520,14 +521,19 @@ def get_ws() -> ClientConnection | None:
520521
return self._ws_unwrapped
521522
return None
522523

524+
async def block_until_connected() -> None:
525+
async with self._connection_condition:
526+
await self._connection_condition.wait()
527+
523528
self._task_manager.create_task(
524529
_buffered_message_sender(
525-
self._connection_condition,
526-
self._message_enqueued,
530+
block_until_connected=block_until_connected,
531+
message_enqueued=self._message_enqueued,
527532
get_ws=get_ws,
528533
websocket_closed_callback=self._begin_close_session_countdown,
529534
get_next_pending=get_next_pending,
530535
commit=commit,
536+
get_state=lambda: self._state,
531537
)
532538
)
533539

@@ -565,8 +571,13 @@ def increment_and_get_heartbeat_misses() -> int:
565571
self._heartbeat_misses += 1
566572
return self._heartbeat_misses
567573

574+
async def block_until_connected() -> None:
575+
async with self._connection_condition:
576+
await self._connection_condition.wait()
577+
568578
self._task_manager.create_task(
569579
_setup_heartbeat(
580+
block_until_connected,
570581
self.session_id,
571582
self._transport_options.heartbeat_ms,
572583
self._transport_options.heartbeats_until_dead,
@@ -628,9 +639,15 @@ def assert_incoming_seq_bookkeeping(
628639
def close_stream(stream_id: str) -> None:
629640
del self._streams[stream_id]
630641

642+
async def block_until_connected() -> None:
643+
async with self._connection_condition:
644+
await self._connection_condition.wait()
645+
646+
631647
self._task_manager.create_task(
632648
_serve(
633-
self._transport_id,
649+
block_until_connected=block_until_connected,
650+
transport_id=self._transport_id,
634651
get_state=lambda: self._state,
635652
get_ws=lambda: self._ws_unwrapped,
636653
transition_connecting=transition_connecting,
@@ -976,23 +993,32 @@ async def _check_to_close_session(
976993

977994

978995
async def _buffered_message_sender(
979-
connection_condition: asyncio.Condition,
996+
block_until_connected: Callable[[], Awaitable[None]],
980997
message_enqueued: asyncio.Semaphore,
981998
get_ws: Callable[[], ClientConnection | None],
982999
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
9831000
get_next_pending: Callable[[], TransportMessage | None],
9841001
commit: Callable[[TransportMessage], None],
1002+
get_state: Callable[[], SessionState],
9851003
) -> None:
986-
while True:
1004+
our_task = asyncio.current_task()
1005+
while our_task and not our_task.cancelling() and not our_task.cancelled():
9871006
await message_enqueued.acquire()
9881007
while (ws := get_ws()) is None:
9891008
# Block until we have a handle
9901009
logger.debug(
991-
"buffered_message_sender: Waiting until ws is connected (condition=%r)",
992-
connection_condition,
1010+
"buffered_message_sender: Waiting until ws is connected",
9931011
)
994-
async with connection_condition:
995-
await connection_condition.wait()
1012+
await block_until_connected()
1013+
1014+
if get_state() in TerminalStates:
1015+
logger.debug("We're going away!")
1016+
return
1017+
1018+
if not ws:
1019+
logger.debug("ws is not connected, loop")
1020+
continue
1021+
9961022
if msg := get_next_pending():
9971023
logger.debug(
9981024
"buffered_message_sender: Dequeued %r to send over %r",
@@ -1025,6 +1051,7 @@ async def _buffered_message_sender(
10251051

10261052

10271053
async def _setup_heartbeat(
1054+
block_until_connected: Callable[[], Awaitable[None]],
10281055
session_id: str,
10291056
heartbeat_ms: float,
10301057
heartbeats_until_dead: int,
@@ -1035,11 +1062,8 @@ async def _setup_heartbeat(
10351062
increment_and_get_heartbeat_misses: Callable[[], int],
10361063
) -> None:
10371064
while True:
1038-
await asyncio.sleep(heartbeat_ms / 1000)
1039-
state = get_state()
1040-
if state == SessionState.CONNECTING:
1041-
logger.debug("Websocket is not connected, not sending heartbeat")
1042-
continue
1065+
while (state := get_state()) in ConnectingStates:
1066+
await block_until_connected()
10431067
if state in TerminalStates:
10441068
logger.debug(
10451069
"Session is closed, no need to send heartbeat, state : "
@@ -1048,7 +1072,13 @@ async def _setup_heartbeat(
10481072
{get_closing_grace_period()},
10491073
)
10501074
# session is closing / closed, no need to send heartbeat anymore
1051-
return
1075+
break
1076+
1077+
await asyncio.sleep(heartbeat_ms / 1000)
1078+
state = get_state()
1079+
if state == SessionState.CONNECTING:
1080+
logger.debug("Websocket is not connected, not sending heartbeat")
1081+
continue
10521082
try:
10531083
await send_message(
10541084
stream_id="heartbeat",
@@ -1080,6 +1110,7 @@ async def _setup_heartbeat(
10801110

10811111

10821112
async def _serve(
1113+
block_until_connected: Callable[[], Awaitable[None]],
10831114
transport_id: str,
10841115
get_state: Callable[[], SessionState],
10851116
get_ws: Callable[[], ClientConnection | None],
@@ -1103,7 +1134,7 @@ async def _serve(
11031134
idx += 1
11041135
while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING:
11051136
logging.debug("_handle_messages_from_ws spinning while connecting")
1106-
await asyncio.sleep(1)
1137+
await block_until_connected()
11071138
logger.debug(
11081139
"%s start handling messages from ws %s",
11091140
"client",

0 commit comments

Comments
 (0)