Skip to content

Commit 980cc17

Browse files
Switch from "queue_full Lock to space_available Event
1 parent 211bbeb commit 980cc17

File tree

2 files changed

+44
-21
lines changed

2 files changed

+44
-21
lines changed

src/replit_river/error_schema.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
# ERROR_CODE_CANCEL is the code used when either server or client cancels the stream.
1818
ERROR_CODE_CANCEL = "CANCEL"
1919

20+
# ERROR_CODE_SESSION_CLOSED is the code used when either server or client closes
21+
# the session.
22+
ERROR_CODE_SESSION_CLOSED = "CLOSED"
23+
2024
# ERROR_CODE_UNKNOWN is the code for the RiverUnknownError
2125
ERROR_CODE_UNKNOWN = "UNKNOWN"
2226

@@ -78,6 +82,16 @@ class StreamClosedRiverServiceException(RiverServiceException):
7882
pass
7983

8084

85+
class SessionClosedRiverServiceException(RiverServiceException):
86+
def __init__(
87+
self,
88+
message: str,
89+
service: str | None,
90+
procedure: str | None,
91+
) -> None:
92+
super().__init__(ERROR_CODE_SESSION_CLOSED, message, service, procedure)
93+
94+
8195
def exception_from_message(code: str) -> type[RiverServiceException]:
8296
"""Return the error class for a given error code."""
8397
if code == ERROR_CODE_STREAM_CLOSED:

src/replit_river/v2/session.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
RiverError,
4040
RiverException,
4141
RiverServiceException,
42+
SessionClosedRiverServiceException,
4243
StreamClosedRiverServiceException,
4344
exception_from_message,
4445
)
@@ -154,8 +155,10 @@ def __init__(
154155

155156
# message state
156157
self._message_enqueued = asyncio.Semaphore()
157-
self._space_available_cond = asyncio.Condition()
158-
self._queue_full_lock = asyncio.Lock()
158+
self._space_available = asyncio.Event()
159+
# Ensure we initialize the above Event to "set" to avoid being blocked from
160+
# the beginning.
161+
self._space_available.set()
159162

160163
# stream for tasks
161164
self._streams: dict[str, Channel[Any]] = {}
@@ -337,22 +340,24 @@ async def _send_message(
337340
with use_span(span):
338341
trace_propagator.inject(msg, None, trace_setter)
339342

340-
# As we prepare to push onto the buffer, if the buffer is full, we lock.
341-
# This lock will be released by the buffered_message_sender task, so it's
342-
# important that we don't release it here.
343-
#
344-
# The reason for this is that in Python, asyncio.Lock is "fair", first
345-
# come, first served.
346-
#
347-
# If somebody else is already waiting or we've filled the buffer, we
348-
# should get in line.
349-
if (
350-
self._queue_full_lock.locked()
351-
or len(self._send_buffer) >= self._transport_options.buffer_size
352-
):
353-
logger.debug("_send_message: queue full, waiting")
354-
await self._queue_full_lock.acquire()
343+
# Ensure the buffer isn't full before we enqueue
344+
await self._space_available.wait()
345+
346+
# Before we append, do an important check
347+
if self._state in TerminalStates:
348+
# session is closing / closed, raise
349+
raise SessionClosedRiverServiceException(
350+
"river session is closed, dropping message",
351+
service_name,
352+
procedure_name,
353+
)
354+
355355
self._send_buffer.append(msg)
356+
357+
# If the buffer is now full, reset the block
358+
if len(self._send_buffer) >= self._transport_options.buffer_size:
359+
self._space_available.clear()
360+
356361
# Wake up buffered_message_sender
357362
self._message_enqueued.release()
358363
self.seq += 1
@@ -368,7 +373,10 @@ async def close(self) -> None:
368373
self._state = SessionState.CLOSING
369374

370375
# We need to wake up all tasks waiting for connection to be established
371-
self._wait_for_connected.clear()
376+
self._wait_for_connected.set()
377+
378+
# We also need to wake up consumers waiting to enqueue messages
379+
self._space_available.set()
372380

373381
await self._task_manager.cancel_all_tasks()
374382

@@ -399,8 +407,7 @@ def commit(msg: TransportMessage) -> None:
399407
self._ack_buffer.append(pending)
400408

401409
# On commit, release pending writers waiting for more buffer space
402-
if self._queue_full_lock.locked():
403-
self._queue_full_lock.release()
410+
self._space_available.set()
404411

405412
def get_next_pending() -> TransportMessage | None:
406413
if self._send_buffer:
@@ -1157,7 +1164,9 @@ async def _serve(
11571164
while our_task and not our_task.cancelling() and not our_task.cancelled():
11581165
logger.debug(f"_serve loop count={idx}")
11591166
idx += 1
1160-
while (ws := get_ws()) is None or (state := get_state()) in ConnectingStates:
1167+
while (ws := get_ws()) is None or (
1168+
state := get_state()
1169+
) in ConnectingStates:
11611170
logger.debug("_handle_messages_from_ws spinning while connecting")
11621171
await block_until_connected()
11631172

0 commit comments

Comments
 (0)