Skip to content

Commit 7caf0d5

Browse files
Just use _state instead of having two
1 parent 34ef3f5 commit 7caf0d5

File tree

4 files changed

+72
-72
lines changed

4 files changed

+72
-72
lines changed

src/replit_river/common_session.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,26 @@ async def __call__(
3535
class SessionState(enum.Enum):
3636
"""The state a session can be in.
3737
38-
Can only transition from ACTIVE to CLOSING to CLOSED.
38+
Valid transitions:
39+
- CONNECTING -> {ACTIVE, CLOSING}
40+
- ACTIVE -> {CONNECTING, CLOSING}
41+
- CLOSING -> {CLOSED}
3942
"""
4043

41-
ACTIVE = 0
42-
CLOSING = 1
43-
CLOSED = 2
44+
CONNECTING = 0
45+
ACTIVE = 1
46+
CLOSING = 2
47+
CLOSED = 3
48+
49+
50+
TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED])
4451

4552

4653
async def setup_heartbeat(
4754
session_id: str,
4855
heartbeat_ms: float,
4956
heartbeats_until_dead: int,
5057
get_state: Callable[[], SessionState],
51-
get_connected: Callable[[], bool],
5258
get_closing_grace_period: Callable[[], float | None],
5359
close_websocket: Callable[[], Awaitable[None]],
5460
send_message: SendMessage,
@@ -58,10 +64,10 @@ async def setup_heartbeat(
5864
while True:
5965
await asyncio.sleep(heartbeat_ms / 1000)
6066
state = get_state()
61-
if not get_connected():
67+
if state == SessionState.CONNECTING:
6268
logger.debug("Websocket is not connected, not sending heartbeat")
6369
continue
64-
if state != SessionState.ACTIVE:
70+
if state in TerminalStates:
6571
logger.debug(
6672
"Session is closed, no need to send heartbeat, state : "
6773
"%r close_session_after_this: %r",
@@ -110,7 +116,7 @@ async def check_to_close_session(
110116
) -> None:
111117
while True:
112118
await asyncio.sleep(close_session_check_interval_ms / 1000)
113-
if get_state() != SessionState.ACTIVE:
119+
if get_state() in TerminalStates:
114120
# already closing
115121
return
116122
# calculate the value now before comparing it so that there are no

src/replit_river/session.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Session:
5151
session_id: str
5252
_transport_options: TransportOptions
5353

54-
# session state, only modified during closing
54+
# session state
5555
_state: SessionState
5656
_state_lock: asyncio.Lock
5757
_close_session_callback: CloseSessionCallback
@@ -88,7 +88,7 @@ def __init__(
8888
self.session_id = session_id
8989
self._transport_options = transport_options
9090

91-
# session state, only modified during closing
91+
# session state
9292
self._state = SessionState.ACTIVE
9393
self._state_lock = asyncio.Lock()
9494
self._close_session_callback = close_session_callback
@@ -123,8 +123,11 @@ def increment_and_get_heartbeat_misses() -> int:
123123
self.session_id,
124124
self._transport_options.heartbeat_ms,
125125
self._transport_options.heartbeats_until_dead,
126-
lambda: self._state,
127-
lambda: self._ws_wrapper.ws_state == WsState.OPEN,
126+
lambda: (
127+
self._state
128+
if self._ws_wrapper.ws_state == WsState.OPEN
129+
else SessionState.CONNECTING
130+
),
128131
lambda: self._close_session_after_time_secs,
129132
close_websocket=do_close_websocket,
130133
send_message=self.send_message,

src/replit_river/v2/client_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async def get_or_create_session(self) -> Session:
6666
call ensure_connected on whatever session is active.
6767
"""
6868
existing_session = self._session
69-
if not existing_session or not existing_session.is_session_open():
69+
if not existing_session or existing_session.is_closed():
7070
logger.info("Creating new session")
7171
new_session = Session(
7272
transport_id=self._transport_id,

src/replit_river/v2/session.py

Lines changed: 50 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from replit_river.common_session import (
3030
SessionState,
31+
TerminalStates,
3132
buffered_message_sender,
3233
check_to_close_session,
3334
setup_heartbeat,
@@ -103,7 +104,6 @@ class Session:
103104
_close_session_after_time_secs: float | None
104105

105106
# ws state
106-
_ws_connected: bool
107107
_ws_unwrapped: ClientConnection | None
108108
_heartbeat_misses: int
109109
_retry_connection_callback: RetryConnectionCallback | None
@@ -133,12 +133,11 @@ def __init__(
133133
self._transport_options = transport_options
134134

135135
# session state, only modified during closing
136-
self._state = SessionState.ACTIVE
136+
self._state = SessionState.CONNECTING
137137
self._close_session_callback = close_session_callback
138138
self._close_session_after_time_secs: float | None = None
139139

140140
# ws state
141-
self._ws_connected = False
142141
self._ws_unwrapped = None
143142
self._heartbeat_misses = 0
144143
self._retry_connection_callback = retry_connection_callback
@@ -160,18 +159,43 @@ def __init__(
160159

161160
async def do_close_websocket() -> None:
162161
logger.debug(
163-
"do_close called, _ws_connected=%r, _ws_unwrapped=%r",
164-
self._ws_connected,
162+
"do_close called, _state=%r, _ws_unwrapped=%r",
163+
self._state,
165164
self._ws_unwrapped,
166165
)
167-
self._ws_connected = False
166+
self._state = SessionState.CLOSING
168167
if self._ws_unwrapped:
169168
self._task_manager.create_task(self._ws_unwrapped.close())
170169
if self._retry_connection_callback:
171170
self._task_manager.create_task(self._retry_connection_callback())
172171
await self._begin_close_session_countdown()
173172

174-
self._setup_heartbeats_task(do_close_websocket)
173+
def increment_and_get_heartbeat_misses() -> int:
174+
self._heartbeat_misses += 1
175+
return self._heartbeat_misses
176+
177+
self._task_manager.create_task(
178+
setup_heartbeat(
179+
self.session_id,
180+
self._transport_options.heartbeat_ms,
181+
self._transport_options.heartbeats_until_dead,
182+
lambda: self._state,
183+
lambda: self._close_session_after_time_secs,
184+
close_websocket=do_close_websocket,
185+
send_message=self.send_message,
186+
increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses,
187+
)
188+
)
189+
self._task_manager.create_task(
190+
check_to_close_session(
191+
self._transport_id,
192+
self._transport_options.close_session_check_interval_ms,
193+
lambda: self._state,
194+
self._get_current_time,
195+
lambda: self._close_session_after_time_secs,
196+
self.close,
197+
)
198+
)
175199

176200
def commit(msg: TransportMessage) -> None:
177201
pending = self._send_buffer.popleft()
@@ -193,7 +217,7 @@ def get_next_pending() -> TransportMessage | None:
193217
self._message_enqueued,
194218
get_ws=lambda: (
195219
cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped)
196-
if self.is_websocket_open()
220+
if self.is_connected()
197221
else None
198222
),
199223
websocket_closed_callback=self._begin_close_session_countdown,
@@ -214,7 +238,7 @@ async def ensure_connected[HandshakeMetadata](
214238
Either return immediately or establish a websocket connection and return
215239
once we can accept messages
216240
"""
217-
if self._ws_unwrapped and self._ws_connected:
241+
if self._ws_unwrapped and self._state == SessionState.ACTIVE:
218242
return
219243
max_retry = self._transport_options.connection_retry_options.max_retry
220244
logger.info("Attempting to establish new ws connection")
@@ -326,6 +350,7 @@ async def websocket_closed_callback() -> None:
326350
)
327351

328352
rate_limiter.start_restoring_budget(client_id)
353+
self._state = SessionState.ACTIVE
329354
except RiverException as e:
330355
await ws.close()
331356
raise e
@@ -342,44 +367,17 @@ async def websocket_closed_callback() -> None:
342367
f"Failed to create ws after retrying {max_retry} number of times",
343368
) from last_error
344369

345-
def _setup_heartbeats_task(
346-
self,
347-
do_close_websocket: Callable[[], Awaitable[None]],
348-
) -> None:
349-
def increment_and_get_heartbeat_misses() -> int:
350-
self._heartbeat_misses += 1
351-
return self._heartbeat_misses
352-
353-
self._task_manager.create_task(
354-
setup_heartbeat(
355-
self.session_id,
356-
self._transport_options.heartbeat_ms,
357-
self._transport_options.heartbeats_until_dead,
358-
lambda: self._state,
359-
lambda: self._ws_connected,
360-
lambda: self._close_session_after_time_secs,
361-
close_websocket=do_close_websocket,
362-
send_message=self.send_message,
363-
increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses,
364-
)
365-
)
366-
self._task_manager.create_task(
367-
check_to_close_session(
368-
self._transport_id,
369-
self._transport_options.close_session_check_interval_ms,
370-
lambda: self._state,
371-
self._get_current_time,
372-
lambda: self._close_session_after_time_secs,
373-
self.close,
374-
)
375-
)
370+
def is_closed(self) -> bool:
371+
"""
372+
If the session is in a terminal state.
373+
Do not send messages, do not expect any more messages to be emitted,
374+
the state is expected to be stale.
375+
"""
376+
return self._state not in TerminalStates
376377

377-
def is_session_open(self) -> bool:
378+
def is_connected(self) -> bool:
378379
return self._state == SessionState.ACTIVE
379380

380-
def is_websocket_open(self) -> bool:
381-
return self._ws_connected
382-
383381
async def _begin_close_session_countdown(self) -> None:
384382
"""Begin the countdown to close session, this should be called when
385383
websocket is closed.
@@ -400,17 +398,6 @@ async def _begin_close_session_countdown(self) -> None:
400398
self._to_id,
401399
)
402400
self._close_session_after_time_secs = close_session_after_time_secs
403-
self._ws_connected = False
404-
405-
async def replace_with_new_websocket(self, new_ws: ClientConnection) -> None:
406-
if self._ws_unwrapped and new_ws.id != self._ws_unwrapped.id:
407-
self._task_manager.create_task(
408-
self._ws_unwrapped.close(
409-
CloseCode.PROTOCOL_ERROR, "Transparent reconnect"
410-
)
411-
)
412-
self._ws_unwrapped = new_ws
413-
self._ws_connected = True
414401

415402
async def _get_current_time(self) -> float:
416403
return asyncio.get_event_loop().time()
@@ -430,7 +417,7 @@ async def send_message(
430417
) -> None:
431418
"""Send serialized messages to the websockets."""
432419
# if the session is not active, we should not do anything
433-
if self._state != SessionState.ACTIVE:
420+
if self._state in TerminalStates:
434421
return
435422
msg = TransportMessage(
436423
streamId=stream_id,
@@ -476,7 +463,7 @@ async def close(self) -> None:
476463
f"{self._transport_id} closing session "
477464
f"to {self._to_id}, ws: {self._ws_unwrapped}"
478465
)
479-
if self._state != SessionState.ACTIVE:
466+
if self._state in TerminalStates:
480467
# already closing
481468
return
482469
self._state = SessionState.CLOSING
@@ -510,6 +497,8 @@ async def _serve(self) -> None:
510497
try:
511498
await self._handle_messages_from_ws()
512499
except ConnectionClosed:
500+
# Set ourselves to closed as soon as we get the signal
501+
self._state = SessionState.CONNECTING
513502
if self._retry_connection_callback:
514503
self._task_manager.create_task(self._retry_connection_callback())
515504

@@ -530,7 +519,7 @@ async def _serve(self) -> None:
530519
)
531520

532521
async def _handle_messages_from_ws(self) -> None:
533-
while self._ws_unwrapped is None or not self._ws_connected:
522+
while self._ws_unwrapped is None or self._state == SessionState.CONNECTING:
534523
await asyncio.sleep(1)
535524
logger.debug(
536525
"%s start handling messages from ws %s",
@@ -628,8 +617,10 @@ async def _handle_messages_from_ws(self) -> None:
628617
await self.close()
629618
return
630619
except ConnectionClosedOK:
631-
pass # Exited normally
620+
# Exited normally
621+
self._state = SessionState.CONNECTING
632622
except ConnectionClosed as e:
623+
self._state = SessionState.CONNECTING
633624
raise e
634625

635626
async def send_rpc[R, A](

0 commit comments

Comments
 (0)