Skip to content

Commit 0374959

Browse files
Do our best to avoid contention on ensure_connected
1 parent 36f852d commit 0374959

File tree

1 file changed

+53
-19
lines changed

1 file changed

+53
-19
lines changed

src/replit_river/v2/session.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ class Session:
9999
_state: SessionState
100100
_close_session_callback: CloseSessionCallback
101101
_close_session_after_time_secs: float | None
102+
_connecting_task: asyncio.Task[Literal[True]] | None
102103

103104
# ws state
104105
_ws_unwrapped: ClientConnection | None
105-
_ensure_connected_condition: asyncio.Condition
106106
_heartbeat_misses: int
107107
_retry_connection_callback: RetryConnectionCallback | None
108108

@@ -134,10 +134,10 @@ def __init__(
134134
self._state = SessionState.PENDING
135135
self._close_session_callback = close_session_callback
136136
self._close_session_after_time_secs: float | None = None
137+
self._connecting_task = None
137138

138139
# ws state
139140
self._ws_unwrapped = None
140-
self._ensure_connected_condition = asyncio.Condition()
141141
self._heartbeat_misses = 0
142142
self._retry_connection_callback = retry_connection_callback
143143

@@ -236,23 +236,38 @@ async def ensure_connected[HandshakeMetadata](
236236
) -> None:
237237
"""
238238
Either return immediately or establish a websocket connection and return
239-
once we can accept messages
239+
once we can accept messages.
240+
241+
One of the goals of this function is to gate exactly one call to the
242+
logic that actually establishes the connection.
240243
"""
241-
max_retry = self._transport_options.connection_retry_options.max_retry
242-
logger.info("Attempting to establish new ws connection")
243244

244245
if self.is_connected():
245246
return
246247

247-
while True:
248-
await self._ensure_connected_condition.acquire()
249-
if self._state == SessionState.ACTIVE:
250-
return
251-
elif self._state == SessionState.PENDING:
252-
self._state = SessionState.CONNECTING
253-
break
254-
elif self._state in TerminalStates:
255-
raise RiverException("SESSION_CLOSING", "Going away")
248+
if not self._connecting_task:
249+
self._connecting_task = asyncio.create_task(
250+
self._do_ensure_connected(
251+
client_id,
252+
rate_limiter,
253+
uri_and_metadata_factory,
254+
protocol_version,
255+
)
256+
)
257+
258+
await self._connecting_task
259+
260+
async def _do_ensure_connected[HandshakeMetadata](
261+
self,
262+
client_id: str,
263+
rate_limiter: LeakyBucketRateLimit,
264+
uri_and_metadata_factory: Callable[
265+
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
266+
], # noqa: E501
267+
protocol_version: str,
268+
) -> Literal[True]:
269+
max_retry = self._transport_options.connection_retry_options.max_retry
270+
logger.info("Attempting to establish new ws connection")
256271

257272
last_error: Exception | None = None
258273
i = 0
@@ -360,9 +375,9 @@ async def websocket_closed_callback() -> None:
360375
+ f"{handshake_response.status.reason}",
361376
)
362377

378+
last_error = None
363379
rate_limiter.start_restoring_budget(client_id)
364380
self._state = SessionState.ACTIVE
365-
self._ensure_connected_condition.notify_all()
366381
except RiverException as e:
367382
await ws.close()
368383
raise e
@@ -374,10 +389,29 @@ async def websocket_closed_callback() -> None:
374389
)
375390
await asyncio.sleep(backoff_time / 1000)
376391

377-
raise RiverException(
378-
ERROR_HANDSHAKE,
379-
f"Failed to create ws after retrying {max_retry} number of times",
380-
) from last_error
392+
# We are in a state where we may throw an exception.
393+
#
394+
# To permit subsequent calls to ensure_connected to pass, we clear ourselves.
395+
# This is safe because each individual function that is waiting on this
396+
# function completeing already has a reference, so we'll last a few ticks
397+
# before GC.
398+
#
399+
# Let's do our best to avoid clobbering other tasks by comparing the .name
400+
current_task = asyncio.current_task()
401+
if (
402+
self._connecting_task
403+
and current_task
404+
and self._connecting_task.get_name() == current_task.get_name()
405+
):
406+
self._connecting_task = None
407+
408+
if last_error is not None:
409+
raise RiverException(
410+
ERROR_HANDSHAKE,
411+
f"Failed to create ws after retrying {max_retry} number of times",
412+
) from last_error
413+
414+
return True
381415

382416
def is_closed(self) -> bool:
383417
"""

0 commit comments

Comments
 (0)