11import asyncio
22import enum
33import logging
4- from typing import Any , Callable , Coroutine
4+ from typing import Any , Awaitable , Callable , Coroutine , Protocol
55
66import nanoid # type: ignore
77import websockets
4242trace_setter = TransportMessageTracingSetter ()
4343
4444
45+ class SendMessage (Protocol ):
46+ async def __call__ (
47+ self ,
48+ * ,
49+ stream_id : str ,
50+ payload : dict [Any , Any ] | str ,
51+ control_flags : int ,
52+ service_name : str | None ,
53+ procedure_name : str | None ,
54+ span : Span | None ,
55+ ) -> None : ...
56+
57+
4558class SessionState (enum .Enum ):
4659 """The state a session can be in.
4760
@@ -53,7 +66,7 @@ class SessionState(enum.Enum):
5366 CLOSED = 2
5467
5568
56- class Session ( object ) :
69+ class Session :
5770 """A transport object that handles the websocket connection with a client."""
5871
5972 def __init__ (
@@ -106,7 +119,29 @@ def __init__(
106119 self ._setup_heartbeats_task ()
107120
108121 def _setup_heartbeats_task (self ) -> None :
109- self ._task_manager .create_task (self ._heartbeat ())
122+ async def do_close_websocket () -> None :
123+ await self .close_websocket (
124+ self ._ws_wrapper ,
125+ should_retry = not self ._is_server ,
126+ )
127+ await self ._begin_close_session_countdown ()
128+
129+ def increment_and_get_heartbeat_misses () -> int :
130+ self ._heartbeat_misses += 1
131+ return self ._heartbeat_misses
132+
133+ self ._task_manager .create_task (
134+ self ._heartbeat (
135+ self .session_id ,
136+ self ._transport_options .heartbeat_ms ,
137+ self ._transport_options .heartbeats_until_dead ,
138+ lambda : self ._state ,
139+ lambda : self ._close_session_after_time_secs ,
140+ close_websocket = do_close_websocket ,
141+ send_message = self .send_message ,
142+ increment_and_get_heartbeat_misses = increment_and_get_heartbeat_misses ,
143+ )
144+ )
110145 self ._task_manager .create_task (self ._check_to_close_session ())
111146
112147 async def is_session_open (self ) -> bool :
@@ -276,45 +311,51 @@ async def _check_to_close_session(self) -> None:
276311
277312 async def _heartbeat (
278313 self ,
314+ session_id : str ,
315+ heartbeat_ms : float ,
316+ heartbeats_until_dead : int ,
317+ get_state : Callable [[], SessionState ],
318+ get_closing_grace_period : Callable [[], float | None ],
319+ close_websocket : Callable [[], Awaitable [None ]],
320+ send_message : SendMessage ,
321+ increment_and_get_heartbeat_misses : Callable [[], int ],
279322 ) -> None :
280323 logger .debug ("Start heartbeat" )
281324 while True :
282- await asyncio .sleep (self ._transport_options .heartbeat_ms / 1000 )
283- if self ._state != SessionState .ACTIVE :
325+ await asyncio .sleep (heartbeat_ms / 1000 )
326+ state = get_state ()
327+ if state != SessionState .ACTIVE :
284328 logger .debug (
285329 "Session is closed, no need to send heartbeat, state : "
286330 "%r close_session_after_this: %r" ,
287- {self . _state },
288- {self . _close_session_after_time_secs },
331+ {state },
332+ {get_closing_grace_period () },
289333 )
290334 # session is closing / closed, no need to send heartbeat anymore
291335 return
292336 try :
293- await self . send_message (
294- "heartbeat" ,
337+ await send_message (
338+ stream_id = "heartbeat" ,
295339 # TODO: make this a message class
296340 # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
297- {
341+ payload = {
298342 "ack" : 0 ,
299343 },
300- ACK_BIT ,
344+ control_flags = ACK_BIT ,
345+ procedure_name = None ,
346+ service_name = None ,
347+ span = None ,
301348 )
302- self ._heartbeat_misses += 1
303- if (
304- self ._heartbeat_misses
305- > self ._transport_options .heartbeats_until_dead
306- ):
307- if self ._close_session_after_time_secs is not None :
349+
350+ if increment_and_get_heartbeat_misses () > heartbeats_until_dead :
351+ if get_closing_grace_period () is not None :
308352 # already in grace period, no need to set again
309353 continue
310354 logger .info (
311355 "%r closing websocket because of heartbeat misses" ,
312- self . session_id ,
356+ session_id ,
313357 )
314- await self .close_websocket (
315- self ._ws_wrapper , should_retry = not self ._is_server
316- )
317- await self ._begin_close_session_countdown ()
358+ await close_websocket ()
318359 continue
319360 except FailedSendingMessageException :
320361 # this is expected during websocket closed period
0 commit comments