Skip to content

Commit 94dbf88

Browse files
Permit sensible message_buffer state transitions
1 parent 513568e commit 94dbf88

File tree

6 files changed

+22
-17
lines changed

6 files changed

+22
-17
lines changed

src/replit_river/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ async def _handle_messages_from_ws(self) -> None:
140140
case other:
141141
assert_never(other)
142142

143-
self._buffer.remove_old_messages(
143+
await self._buffer.remove_old_messages(
144144
self._seq_manager.receiver_ack,
145145
)
146146
self._reset_session_close_countdown()

src/replit_river/codegen/run.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ def main() -> None:
5252
default="v1.1",
5353
choices=["v1.1", "v2.0"],
5454
)
55-
client.add_argument(
56-
"--method-filter",
57-
help="Only generate a subset of the specified methods",
58-
action="store",
59-
type=pathlib.Path,
60-
)
6155
client.add_argument("schema", help="schema file")
6256
args = parser.parse_args()
6357

@@ -82,11 +76,11 @@ def file_opener(path: Path) -> TextIO:
8276
return open(path, "w")
8377

8478
schema_to_river_client_codegen(
85-
lambda: open(schema_path),
86-
target_path,
87-
args.client_name,
88-
args.typed_dict_inputs,
89-
file_opener,
79+
read_schema=lambda: open(schema_path),
80+
target_path=target_path,
81+
client_name=args.client_name,
82+
typed_dict_inputs=args.typed_dict_inputs,
83+
file_opener=file_opener,
9084
method_filter=method_filter,
9185
protocol_version=args.protocol_version,
9286
)

src/replit_river/message_buffer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class MessageBuffer:
1717
def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
1818
self.max_size = max_num_messages
1919
self.buffer: list[TransportMessage] = []
20+
self._has_messages = asyncio.Event()
2021
self._space_available_cond = asyncio.Condition()
2122
self._closed = False
2223

@@ -35,6 +36,7 @@ def put(self, message: TransportMessage) -> None:
3536
if self._closed:
3637
raise MessageBufferClosedError("message buffer is closed")
3738
self.buffer.append(message)
39+
self._has_messages.set()
3840

3941
def get_next_sent_seq(self) -> int | None:
4042
if self.buffer:
@@ -47,16 +49,25 @@ def peek(self) -> TransportMessage | None:
4749
return None
4850
return self.buffer[0]
4951

50-
def remove_old_messages(self, min_seq: int) -> None:
52+
async def remove_old_messages(self, min_seq: int) -> None:
5153
"""Remove messages in the buffer with a seq number less than min_seq."""
5254
self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq]
55+
if self.buffer:
56+
self._has_messages.set()
57+
else:
58+
self._has_messages.clear()
5359
async with self._space_available_cond:
5460
self._space_available_cond.notify_all()
5561

56-
def close(self) -> None:
62+
async def block_until_message_available(self) -> None:
63+
"""Allow consumers to avoid spinning unnecessarily"""
64+
await self._has_messages.wait()
65+
66+
async def close(self) -> None:
5767
"""
5868
Closes the message buffer and rejects any pending put operations.
5969
"""
6070
self._closed = True
71+
self._has_messages.set()
6172
async with self._space_available_cond:
6273
self._space_available_cond.notify_all()

src/replit_river/server_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
136136
pass
137137
case other:
138138
assert_never(other)
139-
self._buffer.remove_old_messages(
139+
await self._buffer.remove_old_messages(
140140
self._seq_manager.receiver_ack,
141141
)
142142
self._reset_session_close_countdown()

src/replit_river/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ async def close(self) -> None:
307307

308308
await self.close_websocket(self._ws_wrapper, should_retry=False)
309309

310-
self._buffer.close()
310+
await self._buffer.close()
311311

312312
# Clear the session in transports
313313
await self._close_session_callback(self)

tests/test_message_buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def put_messages() -> None:
4545
# Wait for the put call to return.
4646
await sync_events.get()
4747
assert len(buffer.buffer) == 1
48-
buffer.remove_old_messages(i)
48+
await buffer.remove_old_messages(i)
4949

5050
await background_puts
5151

0 commit comments

Comments
 (0)