Skip to content

Commit 0c07f6f

Browse files
Switch message_enqueued semaphore to Event to avoid out-of-sync bugs
Semaphore length and _send_buffer were maintained 1:1, but it still left the opportunity for bugs in the future. Switching to an Event lets us only ever care about the length of the _send_buffer.
1 parent 980cc17 commit 0c07f6f

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

src/replit_river/v2/session.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class Session:
114114
_heartbeat_misses: int
115115
_retry_connection_callback: RetryConnectionCallback | None
116116

117+
# message state
118+
_process_messages: asyncio.Event
119+
_space_available: asyncio.Event
120+
117121
# stream for tasks
118122
_streams: dict[str, Channel[Any]]
119123

@@ -154,7 +158,7 @@ def __init__(
154158
self._retry_connection_callback = retry_connection_callback
155159

156160
# message state
157-
self._message_enqueued = asyncio.Semaphore()
161+
self._process_messages = asyncio.Event()
158162
self._space_available = asyncio.Event()
159163
# Ensure we initialize the above Event to "set" to avoid being blocked from
160164
# the beginning.
@@ -359,7 +363,7 @@ async def _send_message(
359363
self._space_available.clear()
360364

361365
# Wake up buffered_message_sender
362-
self._message_enqueued.release()
366+
self._process_messages.set()
363367
self.seq += 1
364368

365369
async def close(self) -> None:
@@ -372,11 +376,13 @@ async def close(self) -> None:
372376
return
373377
self._state = SessionState.CLOSING
374378

375-
# We need to wake up all tasks waiting for connection to be established
379+
# We're closing, so we need to wake up...
380+
# ... tasks waiting for connection to be established
376381
self._wait_for_connected.set()
377-
378-
# We also need to wake up consumers waiting to enqueue messages
382+
# ... consumers waiting to enqueue messages
379383
self._space_available.set()
384+
# ... message processor so it can exit cleanly
385+
self._process_messages.set()
380386

381387
await self._task_manager.cancel_all_tasks()
382388

@@ -406,8 +412,12 @@ def commit(msg: TransportMessage) -> None:
406412
logger.error("Out of sequence error")
407413
self._ack_buffer.append(pending)
408414

409-
# On commit, release pending writers waiting for more buffer space
415+
# On commit...
416+
# ... release pending writers waiting for more buffer space
410417
self._space_available.set()
418+
# ... tell the message sender to back off if there are no pending messages
419+
if not self._send_buffer:
420+
self._process_messages.clear()
411421

412422
def get_next_pending() -> TransportMessage | None:
413423
if self._send_buffer:
@@ -422,10 +432,13 @@ def get_ws() -> ClientConnection | None:
422432
async def block_until_connected() -> None:
423433
await self._wait_for_connected.wait()
424434

435+
async def block_until_message_available() -> None:
436+
await self._process_messages.wait()
437+
425438
self._task_manager.create_task(
426439
_buffered_message_sender(
427440
block_until_connected=block_until_connected,
428-
message_enqueued=self._message_enqueued,
441+
block_until_message_available=block_until_message_available,
429442
get_ws=get_ws,
430443
websocket_closed_callback=self._begin_close_session_countdown,
431444
get_next_pending=get_next_pending,
@@ -865,7 +878,7 @@ async def send_close_stream(
865878

866879
async def _buffered_message_sender(
867880
block_until_connected: Callable[[], Awaitable[None]],
868-
message_enqueued: asyncio.Semaphore,
881+
block_until_message_available: Callable[[], Awaitable[None]],
869882
get_ws: Callable[[], ClientConnection | None],
870883
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
871884
get_next_pending: Callable[[], TransportMessage | None],
@@ -874,18 +887,19 @@ async def _buffered_message_sender(
874887
) -> None:
875888
our_task = asyncio.current_task()
876889
while our_task and not our_task.cancelling() and not our_task.cancelled():
877-
await message_enqueued.acquire()
890+
await block_until_message_available()
891+
892+
if get_state() in TerminalStates:
893+
logger.debug("buffered_message_sender: closing")
894+
return
895+
878896
while (ws := get_ws()) is None:
879897
# Block until we have a handle
880898
logger.debug(
881899
"buffered_message_sender: Waiting until ws is connected",
882900
)
883901
await block_until_connected()
884902

885-
if get_state() in TerminalStates:
886-
logger.debug("We're going away!")
887-
return
888-
889903
if not ws:
890904
logger.debug("ws is not connected, loop")
891905
continue
@@ -906,18 +920,15 @@ async def _buffered_message_sender(
906920
type(e),
907921
exc_info=e,
908922
)
909-
message_enqueued.release()
910923
break
911924
except FailedSendingMessageException:
912925
logger.error(
913926
"Failed sending message, waiting for retry from buffer",
914927
exc_info=True,
915928
)
916-
message_enqueued.release()
917929
break
918930
except Exception:
919931
logger.exception("Error attempting to send buffered messages")
920-
message_enqueued.release()
921932
break
922933

923934

0 commit comments

Comments
 (0)