@@ -99,10 +99,10 @@ class Session:
9999 _state : SessionState
100100 _close_session_callback : CloseSessionCallback
101101 _close_session_after_time_secs : float | None
102+ _connecting_task : asyncio .Task [Literal [True ]] | None
102103
103104 # ws state
104105 _ws_unwrapped : ClientConnection | None
105- _ensure_connected_condition : asyncio .Condition
106106 _heartbeat_misses : int
107107 _retry_connection_callback : RetryConnectionCallback | None
108108
@@ -134,10 +134,10 @@ def __init__(
134134 self ._state = SessionState .PENDING
135135 self ._close_session_callback = close_session_callback
136136 self ._close_session_after_time_secs : float | None = None
137+ self ._connecting_task = None
137138
138139 # ws state
139140 self ._ws_unwrapped = None
140- self ._ensure_connected_condition = asyncio .Condition ()
141141 self ._heartbeat_misses = 0
142142 self ._retry_connection_callback = retry_connection_callback
143143
@@ -236,23 +236,38 @@ async def ensure_connected[HandshakeMetadata](
236236 ) -> None :
237237 """
238238 Either return immediately or establish a websocket connection and return
239- once we can accept messages
239+ once we can accept messages.
240+
241+ One of the goals of this function is to gate exactly one call to the
242+ logic that actually establishes the connection.
240243 """
241- max_retry = self ._transport_options .connection_retry_options .max_retry
242- logger .info ("Attempting to establish new ws connection" )
243244
244245 if self .is_connected ():
245246 return
246247
247- while True :
248- await self ._ensure_connected_condition .acquire ()
249- if self ._state == SessionState .ACTIVE :
250- return
251- elif self ._state == SessionState .PENDING :
252- self ._state = SessionState .CONNECTING
253- break
254- elif self ._state in TerminalStates :
255- raise RiverException ("SESSION_CLOSING" , "Going away" )
248+ if not self ._connecting_task :
249+ self ._connecting_task = asyncio .create_task (
250+ self ._do_ensure_connected (
251+ client_id ,
252+ rate_limiter ,
253+ uri_and_metadata_factory ,
254+ protocol_version ,
255+ )
256+ )
257+
258+ await self ._connecting_task
259+
260+ async def _do_ensure_connected [HandshakeMetadata ](
261+ self ,
262+ client_id : str ,
263+ rate_limiter : LeakyBucketRateLimit ,
264+ uri_and_metadata_factory : Callable [
265+ [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
266+ ], # noqa: E501
267+ protocol_version : str ,
268+ ) -> Literal [True ]:
269+ max_retry = self ._transport_options .connection_retry_options .max_retry
270+ logger .info ("Attempting to establish new ws connection" )
256271
257272 last_error : Exception | None = None
258273 i = 0
@@ -360,9 +375,9 @@ async def websocket_closed_callback() -> None:
360375 + f"{ handshake_response .status .reason } " ,
361376 )
362377
378+ last_error = None
363379 rate_limiter .start_restoring_budget (client_id )
364380 self ._state = SessionState .ACTIVE
365- self ._ensure_connected_condition .notify_all ()
366381 except RiverException as e :
367382 await ws .close ()
368383 raise e
@@ -374,10 +389,29 @@ async def websocket_closed_callback() -> None:
374389 )
375390 await asyncio .sleep (backoff_time / 1000 )
376391
377- raise RiverException (
378- ERROR_HANDSHAKE ,
379- f"Failed to create ws after retrying { max_retry } number of times" ,
380- ) from last_error
392+ # We are in a state where we may throw an exception.
393+ #
394+ # To permit subsequent calls to ensure_connected to pass, we clear ourselves.
395+ # This is safe because each individual function that is waiting on this
396+ # function completing already has a reference, so we'll last a few ticks
397+ # before GC.
398+ #
399+ # Let's do our best to avoid clobbering other tasks by comparing the .name
400+ current_task = asyncio .current_task ()
401+ if (
402+ self ._connecting_task
403+ and current_task
404+ and self ._connecting_task .get_name () == current_task .get_name ()
405+ ):
406+ self ._connecting_task = None
407+
408+ if last_error is not None :
409+ raise RiverException (
410+ ERROR_HANDSHAKE ,
411+ f"Failed to create ws after retrying { max_retry } number of times" ,
412+ ) from last_error
413+
414+ return True
381415
382416 def is_closed (self ) -> bool :
383417 """
0 commit comments