@@ -102,6 +102,7 @@ class Session:
102102
103103 # ws state
104104 _ws_unwrapped : ClientConnection | None
105+ _ensure_connected_condition : asyncio .Condition
105106 _heartbeat_misses : int
106107 _retry_connection_callback : RetryConnectionCallback | None
107108
@@ -129,13 +130,14 @@ def __init__(
129130 self .session_id = session_id
130131 self ._transport_options = transport_options
131132
132- # session state, only modified during closing
133- self ._state = SessionState .CONNECTING
133+ # session state
134+ self ._state = SessionState .PENDING
134135 self ._close_session_callback = close_session_callback
135136 self ._close_session_after_time_secs : float | None = None
136137
137138 # ws state
138139 self ._ws_unwrapped = None
140+ self ._ensure_connected_condition = asyncio .Condition ()
139141 self ._heartbeat_misses = 0
140142 self ._retry_connection_callback = retry_connection_callback
141143
@@ -236,11 +238,22 @@ async def ensure_connected[HandshakeMetadata](
236238 Either return immediately or establish a websocket connection and return
237239 once we can accept messages
238240 """
239- if self ._ws_unwrapped and self ._state == SessionState .ACTIVE :
240- return
241241 max_retry = self ._transport_options .connection_retry_options .max_retry
242242 logger .info ("Attempting to establish new ws connection" )
243243
244+ if self .is_connected ():
245+ return
246+
247+ while True :
248+ await self ._ensure_connected_condition .acquire ()
249+ if self ._state == SessionState .ACTIVE :
250+ return
251+ elif self ._state == SessionState .PENDING :
252+ self ._state = SessionState .CONNECTING
253+ break
254+ elif self ._state in TerminalStates :
255+ raise RiverException ("SESSION_CLOSING" , "Going away" )
256+
244257 last_error : Exception | None = None
245258 i = 0
246259 while rate_limiter .has_budget_or_throw (client_id , ERROR_HANDSHAKE , last_error ):
@@ -349,6 +362,7 @@ async def websocket_closed_callback() -> None:
349362
350363 rate_limiter .start_restoring_budget (client_id )
351364 self ._state = SessionState .ACTIVE
365+ self ._ensure_connected_condition .notify_all ()
352366 except RiverException as e :
353367 await ws .close ()
354368 raise e
0 commit comments