2626from websockets .exceptions import ConnectionClosed , ConnectionClosedOK
2727
2828from replit_river .common_session import (
29+ ConnectingStates ,
2930 SendMessage ,
3031 SessionState ,
3132 TerminalStates ,
@@ -520,14 +521,19 @@ def get_ws() -> ClientConnection | None:
520521 return self ._ws_unwrapped
521522 return None
522523
524+ async def block_until_connected () -> None :
525+ async with self ._connection_condition :
526+ await self ._connection_condition .wait ()
527+
523528 self ._task_manager .create_task (
524529 _buffered_message_sender (
525- self . _connection_condition ,
526- self ._message_enqueued ,
530+ block_until_connected = block_until_connected ,
531+ message_enqueued = self ._message_enqueued ,
527532 get_ws = get_ws ,
528533 websocket_closed_callback = self ._begin_close_session_countdown ,
529534 get_next_pending = get_next_pending ,
530535 commit = commit ,
536+ get_state = lambda : self ._state ,
531537 )
532538 )
533539
@@ -565,8 +571,13 @@ def increment_and_get_heartbeat_misses() -> int:
565571 self ._heartbeat_misses += 1
566572 return self ._heartbeat_misses
567573
574+ async def block_until_connected () -> None :
575+ async with self ._connection_condition :
576+ await self ._connection_condition .wait ()
577+
568578 self ._task_manager .create_task (
569579 _setup_heartbeat (
580+ block_until_connected ,
570581 self .session_id ,
571582 self ._transport_options .heartbeat_ms ,
572583 self ._transport_options .heartbeats_until_dead ,
@@ -628,9 +639,15 @@ def assert_incoming_seq_bookkeeping(
628639 def close_stream (stream_id : str ) -> None :
629640 del self ._streams [stream_id ]
630641
642+ async def block_until_connected () -> None :
643+ async with self ._connection_condition :
644+ await self ._connection_condition .wait ()
645+
646+
631647 self ._task_manager .create_task (
632648 _serve (
633- self ._transport_id ,
649+ block_until_connected = block_until_connected ,
650+ transport_id = self ._transport_id ,
634651 get_state = lambda : self ._state ,
635652 get_ws = lambda : self ._ws_unwrapped ,
636653 transition_connecting = transition_connecting ,
@@ -976,23 +993,32 @@ async def _check_to_close_session(
976993
977994
978995async def _buffered_message_sender (
979- connection_condition : asyncio . Condition ,
996+ block_until_connected : Callable [[], Awaitable [ None ]] ,
980997 message_enqueued : asyncio .Semaphore ,
981998 get_ws : Callable [[], ClientConnection | None ],
982999 websocket_closed_callback : Callable [[], Coroutine [Any , Any , None ]],
9831000 get_next_pending : Callable [[], TransportMessage | None ],
9841001 commit : Callable [[TransportMessage ], None ],
1002+ get_state : Callable [[], SessionState ],
9851003) -> None :
986- while True :
1004+ our_task = asyncio .current_task ()
1005+ while our_task and not our_task .cancelling () and not our_task .cancelled ():
9871006 await message_enqueued .acquire ()
9881007 while (ws := get_ws ()) is None :
9891008 # Block until we have a handle
9901009 logger .debug (
991- "buffered_message_sender: Waiting until ws is connected (condition=%r)" ,
992- connection_condition ,
1010+ "buffered_message_sender: Waiting until ws is connected" ,
9931011 )
994- async with connection_condition :
995- await connection_condition .wait ()
1012+ await block_until_connected ()
1013+
1014+ if get_state () in TerminalStates :
1015+ logger .debug ("We're going away!" )
1016+ return
1017+
1018+ if not ws :
1019+ logger .debug ("ws is not connected, loop" )
1020+ continue
1021+
9961022 if msg := get_next_pending ():
9971023 logger .debug (
9981024 "buffered_message_sender: Dequeued %r to send over %r" ,
@@ -1025,6 +1051,7 @@ async def _buffered_message_sender(
10251051
10261052
10271053async def _setup_heartbeat (
1054+ block_until_connected : Callable [[], Awaitable [None ]],
10281055 session_id : str ,
10291056 heartbeat_ms : float ,
10301057 heartbeats_until_dead : int ,
@@ -1035,11 +1062,8 @@ async def _setup_heartbeat(
10351062 increment_and_get_heartbeat_misses : Callable [[], int ],
10361063) -> None :
10371064 while True :
1038- await asyncio .sleep (heartbeat_ms / 1000 )
1039- state = get_state ()
1040- if state == SessionState .CONNECTING :
1041- logger .debug ("Websocket is not connected, not sending heartbeat" )
1042- continue
1065+ while (state := get_state ()) in ConnectingStates :
1066+ await block_until_connected ()
10431067 if state in TerminalStates :
10441068 logger .debug (
10451069 "Session is closed, no need to send heartbeat, state : "
@@ -1048,7 +1072,13 @@ async def _setup_heartbeat(
10481072 {get_closing_grace_period ()},
10491073 )
10501074 # session is closing / closed, no need to send heartbeat anymore
1051- return
1075+ break
1076+
1077+ await asyncio .sleep (heartbeat_ms / 1000 )
1078+ state = get_state ()
1079+ if state == SessionState .CONNECTING :
1080+ logger .debug ("Websocket is not connected, not sending heartbeat" )
1081+ continue
10521082 try :
10531083 await send_message (
10541084 stream_id = "heartbeat" ,
@@ -1080,6 +1110,7 @@ async def _setup_heartbeat(
10801110
10811111
10821112async def _serve (
1113+ block_until_connected : Callable [[], Awaitable [None ]],
10831114 transport_id : str ,
10841115 get_state : Callable [[], SessionState ],
10851116 get_ws : Callable [[], ClientConnection | None ],
@@ -1103,7 +1134,7 @@ async def _serve(
11031134 idx += 1
11041135 while (ws := get_ws ()) is None or get_state () == SessionState .CONNECTING :
11051136 logging .debug ("_handle_messages_from_ws spinning while connecting" )
1106- await asyncio . sleep ( 1 )
1137+ await block_until_connected ( )
11071138 logger .debug (
11081139 "%s start handling messages from ws %s" ,
11091140 "client" ,
0 commit comments