Skip to content

Commit 583a951

Browse files
Permit sensible message_buffer state transitions
1 parent a4574e2 commit 583a951

File tree

5 files changed

+18
-7
lines changed

5 files changed

+18
-7
lines changed

src/replit_river/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ async def _handle_messages_from_ws(self) -> None:
140140
case other:
141141
assert_never(other)
142142

143-
self._buffer.remove_old_messages(
143+
await self._buffer.remove_old_messages(
144144
self._seq_manager.receiver_ack,
145145
)
146146
self._reset_session_close_countdown()

src/replit_river/message_buffer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class MessageBuffer:
1717
def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
1818
self.max_size = max_num_messages
1919
self.buffer: list[TransportMessage] = []
20+
self._has_messages = asyncio.Event()
2021
self._space_available_cond = asyncio.Condition()
2122
self._closed = False
2223

@@ -33,23 +34,33 @@ async def put(self, message: TransportMessage) -> None:
3334
if self._closed:
3435
raise MessageBufferClosedError("message buffer is closed")
3536
self.buffer.append(message)
37+
self._has_messages.set()
3638

3739
def peek(self) -> TransportMessage | None:
3840
"""Peek the first message in the buffer, returns None if the buffer is empty."""
3941
if len(self.buffer) == 0:
4042
return None
4143
return self.buffer[0]
4244

43-
def remove_old_messages(self, min_seq: int) -> None:
45+
async def remove_old_messages(self, min_seq: int) -> None:
4446
"""Remove messages in the buffer with a seq number less than min_seq."""
4547
self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq]
48+
if self.buffer:
49+
self._has_messages.set()
50+
else:
51+
self._has_messages.clear()
4652
async with self._space_available_cond:
4753
self._space_available_cond.notify_all()
4854

49-
def close(self) -> None:
55+
async def block_until_message_available(self) -> None:
56+
"""Allow consumers to avoid spinning unnecessarily"""
57+
await self._has_messages.wait()
58+
59+
async def close(self) -> None:
5060
"""
5161
Closes the message buffer and rejects any pending put operations.
5262
"""
5363
self._closed = True
64+
self._has_messages.set()
5465
async with self._space_available_cond:
5566
self._space_available_cond.notify_all()

src/replit_river/server_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
136136
pass
137137
case other:
138138
assert_never(other)
139-
self._buffer.remove_old_messages(
139+
await self._buffer.remove_old_messages(
140140
self._seq_manager.receiver_ack,
141141
)
142142
self._reset_session_close_countdown()

src/replit_river/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ async def close(self) -> None:
325325

326326
await self.close_websocket(self._ws_wrapper, should_retry=False)
327327

328-
self._buffer.close()
328+
await self._buffer.close()
329329

330330
# Clear the session in transports
331331
await self._close_session_callback(self)

tests/test_message_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def put_messages() -> None:
4444
# Wait for the put call to return.
4545
await sync_events.get()
4646
assert len(buffer.buffer) == 1
47-
buffer.remove_old_messages(i)
47+
await buffer.remove_old_messages(i)
4848

4949
await background_puts
5050

@@ -57,7 +57,7 @@ async def test_message_buffer_close() -> None:
5757
buffer = MessageBuffer(max_num_messages=1)
5858
await buffer.put(mock_transport_message(seq=1))
5959
background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1)))
60-
buffer.close()
60+
await buffer.close()
6161
with pytest.raises(MessageBufferClosedError):
6262
await background_put
6363
with pytest.raises(MessageBufferClosedError):

0 commit comments

Comments
 (0)