Skip to content

Commit 8a3d5f3

Browse files
Break out has_capacity vs put to avoid race
1 parent 9558fd4 commit 8a3d5f3

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

src/replit_river/message_buffer.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,22 @@ def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
2121
self._space_available_cond = asyncio.Condition()
2222
self._closed = False
2323

24-
async def put(self, message: TransportMessage) -> None:
24+
async def has_capacity(self) -> None:
25+
async with self._space_available_cond:
26+
await self._space_available_cond.wait_for(
27+
lambda: len(self.buffer) < self.max_size or self._closed
28+
)
29+
30+
def put(self, message: TransportMessage) -> None:
2531
"""Add a message to the buffer. Blocks until there is space in the buffer.
2632
2733
Raises:
2834
MessageBufferClosedError: if the buffer is closed.
2935
"""
30-
async with self._space_available_cond:
31-
await self._space_available_cond.wait_for(
32-
lambda: len(self.buffer) < self.max_size or self._closed
33-
)
34-
if self._closed:
35-
raise MessageBufferClosedError("message buffer is closed")
36-
self.buffer.append(message)
37-
self._has_messages.set()
36+
if self._closed:
37+
raise MessageBufferClosedError("message buffer is closed")
38+
self.buffer.append(message)
39+
self._has_messages.set()
3840

3941
def peek(self) -> TransportMessage | None:
4042
"""Peek the first message in the buffer, returns None if the buffer is empty."""

src/replit_river/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,9 @@ async def send_message(
255255
# We need this lock to ensure the buffer order and message sending order
256256
# are the same.
257257
async with self._msg_lock:
258+
await self._buffer.has_capacity()
258259
try:
259-
await self._buffer.put(msg)
260+
self._buffer.put(msg)
260261
except MessageBufferClosedError:
261262
# The session is closed and is no longer accepting new messages.
262263
return

tests/test_message_buffer.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ async def test_message_buffer_backpressure() -> None:
3535

3636
async def put_messages() -> None:
3737
for i in range(0, iterations):
38-
await buffer.put(mock_transport_message(seq=i))
38+
await buffer.has_capacity()
39+
buffer.put(mock_transport_message(seq=i))
3940
await sync_events.put(None)
4041

4142
background_puts = asyncio.create_task(put_messages())
@@ -55,10 +56,17 @@ async def test_message_buffer_close() -> None:
5556
is closed while the put operation is waiting for space in the buffer.
5657
"""
5758
buffer = MessageBuffer(max_num_messages=1)
58-
await buffer.put(mock_transport_message(seq=1))
59-
background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1)))
59+
await buffer.has_capacity()
60+
buffer.put(mock_transport_message(seq=1))
61+
62+
async def bg_put(msg: TransportMessage) -> None:
63+
await buffer.has_capacity()
64+
buffer.put(msg)
65+
66+
background_put = asyncio.create_task(bg_put(mock_transport_message(seq=1)))
6067
await buffer.close()
6168
with pytest.raises(MessageBufferClosedError):
6269
await background_put
6370
with pytest.raises(MessageBufferClosedError):
64-
await buffer.put(mock_transport_message(seq=1))
71+
await buffer.has_capacity()
72+
buffer.put(mock_transport_message(seq=1))

0 commit comments

Comments
 (0)