@@ -603,7 +603,6 @@ async def block_until_connected() -> None:
603603 lambda : self ._state ,
604604 lambda : self ._close_session_after_time_secs ,
605605 close_websocket = close_websocket ,
606- send_message = self .send_message ,
607606 increment_and_get_heartbeat_misses = increment_and_get_heartbeat_misses ,
608607 )
609608 )
@@ -662,6 +661,9 @@ async def block_until_connected() -> None:
662661 async with self ._connection_condition :
663662 await self ._connection_condition .wait ()
664663
664+ def received_message (message : TransportMessage ) -> None :
665+ pass
666+
665667 self ._task_manager .create_task (
666668 _serve (
667669 block_until_connected = block_until_connected ,
@@ -675,6 +677,8 @@ async def block_until_connected() -> None:
675677 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
676678 get_stream = lambda stream_id : self ._streams .get (stream_id ),
677679 close_stream = close_stream ,
680+ received_message = received_message ,
681+ send_message = self .send_message ,
678682 )
679683 )
680684
@@ -1073,7 +1077,6 @@ async def _setup_heartbeat(
10731077 get_state : Callable [[], SessionState ],
10741078 get_closing_grace_period : Callable [[], float | None ],
10751079 close_websocket : Callable [[], Awaitable [None ]],
1076- send_message : SendMessage ,
10771080 increment_and_get_heartbeat_misses : Callable [[], int ],
10781081) -> None :
10791082 while True :
@@ -1092,36 +1095,18 @@ async def _setup_heartbeat(
10921095 await asyncio .sleep (heartbeat_ms / 1000 )
10931096 state = get_state ()
10941097 if state in ConnectingStates :
1095- logger .debug ("Websocket is not connected, not sending heartbeat" )
1098+ logger .debug ("Websocket is not connected, don't expect heartbeat" )
10961099 continue
1097- try :
1098- await send_message (
1099- stream_id = "heartbeat" ,
1100- # TODO: make this a message class
1101- # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
1102- payload = {
1103- "type" : "ACK" ,
1104- "ack" : 0 ,
1105- },
1106- control_flags = ACK_BIT ,
1107- procedure_name = None ,
1108- service_name = None ,
1109- span = None ,
1110- )
11111100
1112- if increment_and_get_heartbeat_misses () > heartbeats_until_dead :
1113- if get_closing_grace_period () is not None :
1114- # already in grace period, no need to set again
1115- continue
1116- logger .info (
1117- "%r closing websocket because of heartbeat misses" ,
1118- session_id ,
1119- )
1120- await close_websocket ()
1101+ if increment_and_get_heartbeat_misses () > heartbeats_until_dead :
1102+ if get_closing_grace_period () is not None :
1103+ # already in grace period, no need to set again
11211104 continue
1122- except FailedSendingMessageException :
1123- # this is expected during websocket closed period
1124- continue
1105+ logger .info (
1106+ "%r closing websocket because of heartbeat misses" ,
1107+ session_id ,
1108+ )
1109+ await close_websocket ()
11251110
11261111
11271112async def _serve (
@@ -1138,6 +1123,8 @@ async def _serve(
11381123 ], # noqa: E501
11391124 get_stream : Callable [[str ], Channel [Any ] | None ],
11401125 close_stream : Callable [[str ], None ],
1126+ received_message : Callable [[TransportMessage ], None ],
1127+ send_message : SendMessage ,
11411128) -> None :
11421129 """Serve messages from the websocket."""
11431130 reset_session_close_countdown ()
@@ -1169,6 +1156,8 @@ async def _serve(
11691156 msg ,
11701157 )
11711158
1159+ received_message (msg )
1160+
11721161 if msg .controlFlags & STREAM_OPEN_BIT != 0 :
11731162 raise InvalidMessageException (
11741163 "Client should not receive stream open bit"
@@ -1194,6 +1183,19 @@ async def _serve(
11941183
11951184 # Shortcut to avoid processing ack packets
11961185 if msg .controlFlags & ACK_BIT != 0 :
1186+ await send_message (
1187+ stream_id = "heartbeat" ,
1188+ # TODO: make this a message class
1189+ # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
1190+ payload = {
1191+ "type" : "ACK" ,
1192+ "ack" : 0 ,
1193+ },
1194+ control_flags = ACK_BIT ,
1195+ procedure_name = None ,
1196+ service_name = None ,
1197+ span = None ,
1198+ )
11971199 continue
11981200
11991201 stream = get_stream (msg .streamId )
0 commit comments