Skip to content

Commit 36f852d

Browse files
Block ensure_connected until connected
1 parent bef6d9e commit 36f852d

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/replit_river/common_session.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ class SessionState(enum.Enum):
4141
- CLOSING -> {CLOSED}
4242
"""
4343

44-
CONNECTING = 0
45-
ACTIVE = 1
46-
CLOSING = 2
47-
CLOSED = 3
44+
PENDING = 0
45+
CONNECTING = 1
46+
ACTIVE = 2
47+
CLOSING = 3
48+
CLOSED = 4
4849

4950

5051
TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED])

src/replit_river/v2/session.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class Session:
102102

103103
# ws state
104104
_ws_unwrapped: ClientConnection | None
105+
_ensure_connected_condition: asyncio.Condition
105106
_heartbeat_misses: int
106107
_retry_connection_callback: RetryConnectionCallback | None
107108

@@ -129,13 +130,14 @@ def __init__(
129130
self.session_id = session_id
130131
self._transport_options = transport_options
131132

132-
# session state, only modified during closing
133-
self._state = SessionState.CONNECTING
133+
# session state
134+
self._state = SessionState.PENDING
134135
self._close_session_callback = close_session_callback
135136
self._close_session_after_time_secs: float | None = None
136137

137138
# ws state
138139
self._ws_unwrapped = None
140+
self._ensure_connected_condition = asyncio.Condition()
139141
self._heartbeat_misses = 0
140142
self._retry_connection_callback = retry_connection_callback
141143

@@ -236,11 +238,22 @@ async def ensure_connected[HandshakeMetadata](
236238
Either return immediately or establish a websocket connection and return
237239
once we can accept messages
238240
"""
239-
if self._ws_unwrapped and self._state == SessionState.ACTIVE:
240-
return
241241
max_retry = self._transport_options.connection_retry_options.max_retry
242242
logger.info("Attempting to establish new ws connection")
243243

244+
if self.is_connected():
245+
return
246+
247+
while True:
248+
await self._ensure_connected_condition.acquire()
249+
if self._state == SessionState.ACTIVE:
250+
return
251+
elif self._state == SessionState.PENDING:
252+
self._state = SessionState.CONNECTING
253+
break
254+
elif self._state in TerminalStates:
255+
raise RiverException("SESSION_CLOSING", "Going away")
256+
244257
last_error: Exception | None = None
245258
i = 0
246259
while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error):
@@ -349,6 +362,7 @@ async def websocket_closed_callback() -> None:
349362

350363
rate_limiter.start_restoring_budget(client_id)
351364
self._state = SessionState.ACTIVE
365+
self._ensure_connected_condition.notify_all()
352366
except RiverException as e:
353367
await ws.close()
354368
raise e

0 commit comments

Comments
 (0)