Skip to content

Commit 907e08b

Browse files
Just use asyncio.Event to represent "connected"
1 parent a41caa2 commit 907e08b

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

src/replit_river/v2/session.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class Session:
106106
_close_session_callback: CloseSessionCallback
107107
_close_session_after_time_secs: float | None
108108
_connecting_task: asyncio.Task[Literal[True]] | None
109-
_connection_condition: asyncio.Condition
109+
_wait_for_connected: asyncio.Event
110110

111111
# ws state
112112
_ws: ClientConnection | None
@@ -145,7 +145,7 @@ def __init__(
145145
self._close_session_callback = close_session_callback
146146
self._close_session_after_time_secs: float | None = None
147147
self._connecting_task = None
148-
self._connection_condition = asyncio.Condition()
148+
self._wait_for_connected = asyncio.Event()
149149

150150
# ws state
151151
self._ws = None
@@ -208,13 +208,16 @@ def do_close() -> None:
208208
# during the cleanup procedure.
209209
self._terminating_task = asyncio.create_task(self.close())
210210

211-
async def transition_connected(ws: ClientConnection) -> None:
211+
def transition_connecting() -> None:
212+
# "Clear" here means observers should wait until we are connected.
213+
self._wait_for_connected.clear()
214+
215+
def transition_connected(ws: ClientConnection) -> None:
212216
self._state = SessionState.ACTIVE
213217
self._ws = ws
214218

215-
# We're connected, wake everybody up
216-
async with self._connection_condition:
217-
self._connection_condition.notify_all()
219+
# We're connected, wake everybody up using set()
220+
self._wait_for_connected.set()
218221

219222
def finalize_attempt() -> None:
220223
# We are in a state where we may throw an exception.
@@ -246,6 +249,7 @@ def finalize_attempt() -> None:
246249
get_next_sent_seq=get_next_sent_seq,
247250
get_current_ack=lambda: self.ack,
248251
get_current_time=self._get_current_time,
252+
transition_connecting=transition_connecting,
249253
transition_connected=transition_connected,
250254
finalize_attempt=finalize_attempt,
251255
do_close=do_close,
@@ -364,8 +368,7 @@ async def close(self) -> None:
364368
self._state = SessionState.CLOSING
365369

366370
# We need to wake up all tasks waiting for connection to be established
367-
async with self._connection_condition:
368-
self._connection_condition.notify_all()
371+
self._wait_for_connected.clear()
369372

370373
await self._task_manager.cancel_all_tasks()
371374

@@ -410,8 +413,7 @@ def get_ws() -> ClientConnection | None:
410413
return None
411414

412415
async def block_until_connected() -> None:
413-
async with self._connection_condition:
414-
await self._connection_condition.wait()
416+
await self._wait_for_connected.wait()
415417

416418
self._task_manager.create_task(
417419
_buffered_message_sender(
@@ -468,8 +470,7 @@ def increment_and_get_heartbeat_misses() -> int:
468470
return self._heartbeat_misses
469471

470472
async def block_until_connected() -> None:
471-
async with self._connection_condition:
472-
await self._connection_condition.wait()
473+
await self._wait_for_connected.wait()
473474

474475
self._task_manager.create_task(
475476
_setup_heartbeat(
@@ -535,8 +536,7 @@ def close_stream(stream_id: str) -> None:
535536
del self._streams[stream_id]
536537

537538
async def block_until_connected() -> None:
538-
async with self._connection_condition:
539-
await self._connection_condition.wait()
539+
await self._wait_for_connected.wait()
540540

541541
self._task_manager.create_task(
542542
_serve(
@@ -954,7 +954,8 @@ async def _do_ensure_connected[HandshakeMetadata](
954954
get_current_time: Callable[[], Awaitable[float]],
955955
get_next_sent_seq: Callable[[], int],
956956
get_current_ack: Callable[[], int],
957-
transition_connected: Callable[[ClientConnection], Awaitable[None]],
957+
transition_connecting: Callable[[], None],
958+
transition_connected: Callable[[ClientConnection], None],
958959
finalize_attempt: Callable[[], None],
959960
do_close: Callable[[], None],
960961
) -> Literal[True]:
@@ -968,6 +969,7 @@ async def _do_ensure_connected[HandshakeMetadata](
968969
i += 1
969970

970971
rate_limiter.consume_budget(client_id)
972+
transition_connecting()
971973

972974
ws = None
973975
try:

0 commit comments

Comments
 (0)