2727from websockets .legacy .protocol import WebSocketCommonProtocol
2828
2929from replit_river .common_session import (
30+ SendMessage ,
3031 SessionState ,
3132 TerminalStates ,
32- buffered_message_sender ,
33- check_to_close_session ,
34- setup_heartbeat ,
3533)
3634from replit_river .error_schema import (
3735 ERROR_CODE_CANCEL ,
@@ -183,7 +181,7 @@ def increment_and_get_heartbeat_misses() -> int:
183181 return self ._heartbeat_misses
184182
185183 self ._task_manager .create_task (
186- setup_heartbeat (
184+ _setup_heartbeat (
187185 self .session_id ,
188186 self ._transport_options .heartbeat_ms ,
189187 self ._transport_options .heartbeats_until_dead ,
@@ -195,7 +193,7 @@ def increment_and_get_heartbeat_misses() -> int:
195193 )
196194 )
197195 self ._task_manager .create_task (
198- check_to_close_session (
196+ _check_to_close_session (
199197 self ._transport_id ,
200198 self ._transport_options .close_session_check_interval_ms ,
201199 lambda : self ._state ,
@@ -227,7 +225,7 @@ def get_ws() -> WebSocketCommonProtocol | ClientConnection | None:
227225 return None
228226
229227 self ._task_manager .create_task (
230- buffered_message_sender (
228+ _buffered_message_sender (
231229 self ._connection_condition ,
232230 self ._message_enqueued ,
233231 get_ws = get_ws ,
@@ -930,6 +928,137 @@ async def send_close_stream(
930928 )
931929
932930
931+ async def _check_to_close_session (
932+ transport_id : str ,
933+ close_session_check_interval_ms : float ,
934+ get_state : Callable [[], SessionState ],
935+ get_current_time : Callable [[], Awaitable [float ]],
936+ get_close_session_after_time_secs : Callable [[], float | None ],
937+ do_close : Callable [[], Awaitable [None ]],
938+ ) -> None :
939+ our_task = asyncio .current_task ()
940+ while our_task and not our_task .cancelling () and not our_task .cancelled ():
941+ await asyncio .sleep (close_session_check_interval_ms / 1000 )
942+ if get_state () in TerminalStates :
943+ # already closing
944+ return
945+ # calculate the value now before comparing it so that there are no
946+ # await points between the check and the comparison to avoid a TOCTOU
947+ # race.
948+ current_time = await get_current_time ()
949+ close_session_after_time_secs = get_close_session_after_time_secs ()
950+ if not close_session_after_time_secs :
951+ continue
952+ if current_time > close_session_after_time_secs :
953+ logger .info ("Grace period ended for %s, closing session" , transport_id )
954+ await do_close ()
955+ return
956+
957+
958+ async def _buffered_message_sender (
959+ connection_condition : asyncio .Condition ,
960+ message_enqueued : asyncio .Semaphore ,
961+ get_ws : Callable [[], WebSocketCommonProtocol | ClientConnection | None ],
962+ websocket_closed_callback : Callable [[], Coroutine [Any , Any , None ]],
963+ get_next_pending : Callable [[], TransportMessage | None ],
964+ commit : Callable [[TransportMessage ], None ],
965+ ) -> None :
966+ while True :
967+ await message_enqueued .acquire ()
968+ while (ws := get_ws ()) is None :
969+ # Block until we have a handle
970+ logger .debug (
971+ "buffered_message_sender: Waiting until ws is connected (condition=%r)" ,
972+ connection_condition ,
973+ )
974+ async with connection_condition :
975+ await connection_condition .wait ()
976+ if msg := get_next_pending ():
977+ logger .debug (
978+ "buffered_message_sender: Dequeued %r to send over %r" ,
979+ msg ,
980+ ws ,
981+ )
982+ try :
983+ await send_transport_message (msg , ws , websocket_closed_callback )
984+ commit (msg )
985+ except WebsocketClosedException as e :
986+ logger .debug (
987+ "Connection closed while sending message %r, waiting for "
988+ "retry from buffer" ,
989+ type (e ),
990+ exc_info = e ,
991+ )
992+ message_enqueued .release ()
993+ break
994+ except FailedSendingMessageException :
995+ logger .error (
996+ "Failed sending message, waiting for retry from buffer" ,
997+ exc_info = True ,
998+ )
999+ message_enqueued .release ()
1000+ break
1001+ except Exception :
1002+ logger .exception ("Error attempting to send buffered messages" )
1003+ message_enqueued .release ()
1004+ break
1005+
1006+
1007+ async def _setup_heartbeat (
1008+ session_id : str ,
1009+ heartbeat_ms : float ,
1010+ heartbeats_until_dead : int ,
1011+ get_state : Callable [[], SessionState ],
1012+ get_closing_grace_period : Callable [[], float | None ],
1013+ close_websocket : Callable [[], Awaitable [None ]],
1014+ send_message : SendMessage ,
1015+ increment_and_get_heartbeat_misses : Callable [[], int ],
1016+ ) -> None :
1017+ while True :
1018+ await asyncio .sleep (heartbeat_ms / 1000 )
1019+ state = get_state ()
1020+ if state == SessionState .CONNECTING :
1021+ logger .debug ("Websocket is not connected, not sending heartbeat" )
1022+ continue
1023+ if state in TerminalStates :
1024+ logger .debug (
1025+ "Session is closed, no need to send heartbeat, state : "
1026+ "%r close_session_after_this: %r" ,
1027+ {state },
1028+ {get_closing_grace_period ()},
1029+ )
1030+ # session is closing / closed, no need to send heartbeat anymore
1031+ return
1032+ try :
1033+ await send_message (
1034+ stream_id = "heartbeat" ,
1035+ # TODO: make this a message class
1036+ # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
1037+ payload = {
1038+ "type" : "ACK" ,
1039+ "ack" : 0 ,
1040+ },
1041+ control_flags = ACK_BIT ,
1042+ procedure_name = None ,
1043+ service_name = None ,
1044+ span = None ,
1045+ )
1046+
1047+ if increment_and_get_heartbeat_misses () > heartbeats_until_dead :
1048+ if get_closing_grace_period () is not None :
1049+ # already in grace period, no need to set again
1050+ continue
1051+ logger .info (
1052+ "%r closing websocket because of heartbeat misses" ,
1053+ session_id ,
1054+ )
1055+ await close_websocket ()
1056+ continue
1057+ except FailedSendingMessageException :
1058+ # this is expected during websocket closed period
1059+ continue
1060+
1061+
9331062async def _serve (
9341063 transport_id : str ,
9351064 get_state : Callable [[], SessionState ],
0 commit comments