diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 465a6672..0d8ffc3c 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -16,8 +16,6 @@ from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, - WebsocketClosedException, - send_transport_message, ) from replit_river.seq_manager import ( SeqManager, @@ -107,6 +105,9 @@ def __init__( self._buffer = MessageBuffer(self._transport_options.buffer_size) self._task_manager = BackgroundTaskManager() + # Start the buffered message sender task + self._start_buffered_message_sender() + def _setup_heartbeats_task( self, do_close_websocket: Callable[[], Awaitable[None]], @@ -142,6 +143,38 @@ def increment_and_get_heartbeat_misses() -> int: ) ) + def _start_buffered_message_sender(self) -> None: + """Start the background task that sends messages from the buffer.""" + from replit_river.common_session import buffered_message_sender + + async def commit(msg: TransportMessage) -> None: + # Remove messages that have been acknowledged + await self._buffer.remove_old_messages(msg.seq + 1) + + def get_next_pending() -> TransportMessage | None: + return self._buffer.peek() + + def get_ws() -> websockets.WebSocketCommonProtocol | None: + if self._ws_wrapper.is_open(): + return self._ws_wrapper.ws + return None + + async def block_until_connected() -> None: + while self._state in [SessionState.NO_CONNECTION, SessionState.CONNECTING]: + await asyncio.sleep(0.1) + + self._task_manager.create_task( + buffered_message_sender( + block_until_connected=block_until_connected, + block_until_message_available=self._buffer.block_until_message_available, + get_ws=get_ws, + websocket_closed_callback=self._begin_close_session_countdown, + get_next_pending=get_next_pending, + commit=commit, + get_state=lambda: self._state, + ) + ) + async def is_session_open(self) -> bool: async with self._state_lock: return self._state == SessionState.ACTIVE @@ -181,24 +214,6 @@ async def replace_with_new_websocket( await old_wrapper.close() self._ws_wrapper = WebsocketWrapper(new_ws) - # Send buffered messages to the new ws - buffered_messages = list(self._buffer.buffer) - for msg in buffered_messages: - try: - await send_transport_message( - msg, - new_ws, - self._begin_close_session_countdown, - ) - except WebsocketClosedException: - logger.info( - "Connection closed while sending buffered messages", exc_info=True - ) - break - except FailedSendingMessageException: - logger.exception("Error while sending buffered messages") - break - async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() @@ -249,30 +264,10 @@ async def send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) try: - try: - self._buffer.put(msg) - except MessageBufferClosedError: - # The session is closed and is no longer accepting new messages. - return - async with self._ws_lock: - if not self._ws_wrapper.is_open(): - # If the websocket is closed, we should not send the message - # and wait for the retry from the buffer. - return - await send_transport_message( - msg, self._ws_wrapper.ws, self._begin_close_session_countdown - ) - except WebsocketClosedException as e: - logger.debug( - "Connection closed while sending message %r, waiting for " - "retry from buffer", - type(e), - exc_info=e, - ) - except FailedSendingMessageException: - logger.error( - "Failed sending message, waiting for retry from buffer", exc_info=True - ) + self._buffer.put(msg) + except MessageBufferClosedError: + # The session is closed and is no longer accepting new messages. + return async def close_websocket( self, ws_wrapper: WebsocketWrapper, should_retry: bool