2828
2929from replit_river .common_session import (
3030 SessionState ,
31+ TerminalStates ,
3132 buffered_message_sender ,
3233 check_to_close_session ,
3334 setup_heartbeat ,
@@ -103,7 +104,6 @@ class Session:
103104 _close_session_after_time_secs : float | None
104105
105106 # ws state
106- _ws_connected : bool
107107 _ws_unwrapped : ClientConnection | None
108108 _heartbeat_misses : int
109109 _retry_connection_callback : RetryConnectionCallback | None
@@ -133,12 +133,11 @@ def __init__(
133133 self ._transport_options = transport_options
134134
135135 # session state, only modified during closing
136- self ._state = SessionState .ACTIVE
136+ self ._state = SessionState .CONNECTING
137137 self ._close_session_callback = close_session_callback
138138 self ._close_session_after_time_secs : float | None = None
139139
140140 # ws state
141- self ._ws_connected = False
142141 self ._ws_unwrapped = None
143142 self ._heartbeat_misses = 0
144143 self ._retry_connection_callback = retry_connection_callback
@@ -160,18 +159,43 @@ def __init__(
160159
161160 async def do_close_websocket () -> None :
162161 logger .debug (
163- "do_close called, _ws_connected =%r, _ws_unwrapped=%r" ,
164- self ._ws_connected ,
162+ "do_close called, _state =%r, _ws_unwrapped=%r" ,
163+ self ._state ,
165164 self ._ws_unwrapped ,
166165 )
167- self ._ws_connected = False
166+ self ._state = SessionState . CLOSING
168167 if self ._ws_unwrapped :
169168 self ._task_manager .create_task (self ._ws_unwrapped .close ())
170169 if self ._retry_connection_callback :
171170 self ._task_manager .create_task (self ._retry_connection_callback ())
172171 await self ._begin_close_session_countdown ()
173172
174- self ._setup_heartbeats_task (do_close_websocket )
173+ def increment_and_get_heartbeat_misses () -> int :
174+ self ._heartbeat_misses += 1
175+ return self ._heartbeat_misses
176+
177+ self ._task_manager .create_task (
178+ setup_heartbeat (
179+ self .session_id ,
180+ self ._transport_options .heartbeat_ms ,
181+ self ._transport_options .heartbeats_until_dead ,
182+ lambda : self ._state ,
183+ lambda : self ._close_session_after_time_secs ,
184+ close_websocket = do_close_websocket ,
185+ send_message = self .send_message ,
186+ increment_and_get_heartbeat_misses = increment_and_get_heartbeat_misses ,
187+ )
188+ )
189+ self ._task_manager .create_task (
190+ check_to_close_session (
191+ self ._transport_id ,
192+ self ._transport_options .close_session_check_interval_ms ,
193+ lambda : self ._state ,
194+ self ._get_current_time ,
195+ lambda : self ._close_session_after_time_secs ,
196+ self .close ,
197+ )
198+ )
175199
176200 def commit (msg : TransportMessage ) -> None :
177201 pending = self ._send_buffer .popleft ()
@@ -193,7 +217,7 @@ def get_next_pending() -> TransportMessage | None:
193217 self ._message_enqueued ,
194218 get_ws = lambda : (
195219 cast (WebSocketCommonProtocol | ClientConnection , self ._ws_unwrapped )
196- if self .is_websocket_open ()
220+ if self .is_connected ()
197221 else None
198222 ),
199223 websocket_closed_callback = self ._begin_close_session_countdown ,
@@ -214,7 +238,7 @@ async def ensure_connected[HandshakeMetadata](
214238 Either return immediately or establish a websocket connection and return
215239 once we can accept messages
216240 """
217- if self ._ws_unwrapped and self ._ws_connected :
241+ if self ._ws_unwrapped and self ._state == SessionState . ACTIVE :
218242 return
219243 max_retry = self ._transport_options .connection_retry_options .max_retry
220244 logger .info ("Attempting to establish new ws connection" )
@@ -326,6 +350,7 @@ async def websocket_closed_callback() -> None:
326350 )
327351
328352 rate_limiter .start_restoring_budget (client_id )
353+ self ._state = SessionState .ACTIVE
329354 except RiverException as e :
330355 await ws .close ()
331356 raise e
@@ -342,44 +367,17 @@ async def websocket_closed_callback() -> None:
342367 f"Failed to create ws after retrying { max_retry } number of times" ,
343368 ) from last_error
344369
345- def _setup_heartbeats_task (
346- self ,
347- do_close_websocket : Callable [[], Awaitable [None ]],
348- ) -> None :
349- def increment_and_get_heartbeat_misses () -> int :
350- self ._heartbeat_misses += 1
351- return self ._heartbeat_misses
352-
353- self ._task_manager .create_task (
354- setup_heartbeat (
355- self .session_id ,
356- self ._transport_options .heartbeat_ms ,
357- self ._transport_options .heartbeats_until_dead ,
358- lambda : self ._state ,
359- lambda : self ._ws_connected ,
360- lambda : self ._close_session_after_time_secs ,
361- close_websocket = do_close_websocket ,
362- send_message = self .send_message ,
363- increment_and_get_heartbeat_misses = increment_and_get_heartbeat_misses ,
364- )
365- )
366- self ._task_manager .create_task (
367- check_to_close_session (
368- self ._transport_id ,
369- self ._transport_options .close_session_check_interval_ms ,
370- lambda : self ._state ,
371- self ._get_current_time ,
372- lambda : self ._close_session_after_time_secs ,
373- self .close ,
374- )
375- )
370+ def is_closed (self ) -> bool :
371+ """
372+ If the session is in a terminal state.
373+ Do not send messages, do not expect any more messages to be emitted,
374+ the state is expected to be stale.
375+ """
376+ return self ._state not in TerminalStates
376377
377- def is_session_open (self ) -> bool :
378+ def is_connected (self ) -> bool :
378379 return self ._state == SessionState .ACTIVE
379380
380- def is_websocket_open (self ) -> bool :
381- return self ._ws_connected
382-
383381 async def _begin_close_session_countdown (self ) -> None :
384382 """Begin the countdown to close session, this should be called when
385383 websocket is closed.
@@ -400,17 +398,6 @@ async def _begin_close_session_countdown(self) -> None:
400398 self ._to_id ,
401399 )
402400 self ._close_session_after_time_secs = close_session_after_time_secs
403- self ._ws_connected = False
404-
405- async def replace_with_new_websocket (self , new_ws : ClientConnection ) -> None :
406- if self ._ws_unwrapped and new_ws .id != self ._ws_unwrapped .id :
407- self ._task_manager .create_task (
408- self ._ws_unwrapped .close (
409- CloseCode .PROTOCOL_ERROR , "Transparent reconnect"
410- )
411- )
412- self ._ws_unwrapped = new_ws
413- self ._ws_connected = True
414401
415402 async def _get_current_time (self ) -> float :
416403 return asyncio .get_event_loop ().time ()
@@ -430,7 +417,7 @@ async def send_message(
430417 ) -> None :
431418 """Send serialized messages to the websockets."""
432419 # if the session is not active, we should not do anything
433- if self ._state != SessionState . ACTIVE :
420+ if self ._state in TerminalStates :
434421 return
435422 msg = TransportMessage (
436423 streamId = stream_id ,
@@ -476,7 +463,7 @@ async def close(self) -> None:
476463 f"{ self ._transport_id } closing session "
477464 f"to { self ._to_id } , ws: { self ._ws_unwrapped } "
478465 )
479- if self ._state != SessionState . ACTIVE :
466+ if self ._state in TerminalStates :
480467 # already closing
481468 return
482469 self ._state = SessionState .CLOSING
@@ -510,6 +497,8 @@ async def _serve(self) -> None:
510497 try :
511498 await self ._handle_messages_from_ws ()
512499 except ConnectionClosed :
500+ # Set ourselves to closed as soon as we get the signal
501+ self ._state = SessionState .CONNECTING
513502 if self ._retry_connection_callback :
514503 self ._task_manager .create_task (self ._retry_connection_callback ())
515504
@@ -530,7 +519,7 @@ async def _serve(self) -> None:
530519 )
531520
532521 async def _handle_messages_from_ws (self ) -> None :
533- while self ._ws_unwrapped is None or not self ._ws_connected :
522+ while self ._ws_unwrapped is None or self ._state == SessionState . CONNECTING :
534523 await asyncio .sleep (1 )
535524 logger .debug (
536525 "%s start handling messages from ws %s" ,
@@ -628,8 +617,10 @@ async def _handle_messages_from_ws(self) -> None:
628617 await self .close ()
629618 return
630619 except ConnectionClosedOK :
631- pass # Exited normally
620+ # Exited normally
621+ self ._state = SessionState .CONNECTING
632622 except ConnectionClosed as e :
623+ self ._state = SessionState .CONNECTING
633624 raise e
634625
635626 async def send_rpc [R , A ](
0 commit comments