Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 39 additions & 44 deletions src/replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
from replit_river.messages import (
FailedSendingMessageException,
WebsocketClosedException,
send_transport_message,
)
from replit_river.seq_manager import (
SeqManager,
Expand Down Expand Up @@ -107,6 +105,9 @@ def __init__(
self._buffer = MessageBuffer(self._transport_options.buffer_size)
self._task_manager = BackgroundTaskManager()

# Start the buffered message sender task
self._start_buffered_message_sender()

def _setup_heartbeats_task(
self,
do_close_websocket: Callable[[], Awaitable[None]],
Expand Down Expand Up @@ -142,6 +143,38 @@ def increment_and_get_heartbeat_misses() -> int:
)
)

def _start_buffered_message_sender(self) -> None:
"""Start the background task that sends messages from the buffer."""
from replit_river.common_session import buffered_message_sender

async def commit(msg: TransportMessage) -> None:
# Remove messages that have been acknowledged
await self._buffer.remove_old_messages(msg.seq + 1)

def get_next_pending() -> TransportMessage | None:
return self._buffer.peek()

def get_ws() -> websockets.WebSocketCommonProtocol | None:
if self._ws_wrapper.is_open():
return self._ws_wrapper.ws
return None

async def block_until_connected() -> None:
while self._state in [SessionState.NO_CONNECTION, SessionState.CONNECTING]:
await asyncio.sleep(0.1)

self._task_manager.create_task(
buffered_message_sender(
block_until_connected=block_until_connected,
block_until_message_available=self._buffer.block_until_message_available,
get_ws=get_ws,
websocket_closed_callback=self._begin_close_session_countdown,
get_next_pending=get_next_pending,
commit=commit,
get_state=lambda: self._state,
)
)

async def is_session_open(self) -> bool:
async with self._state_lock:
return self._state == SessionState.ACTIVE
Expand Down Expand Up @@ -181,24 +214,6 @@ async def replace_with_new_websocket(
await old_wrapper.close()
self._ws_wrapper = WebsocketWrapper(new_ws)

# Send buffered messages to the new ws
buffered_messages = list(self._buffer.buffer)
for msg in buffered_messages:
try:
await send_transport_message(
msg,
new_ws,
self._begin_close_session_countdown,
)
except WebsocketClosedException:
logger.info(
"Connection closed while sending buffered messages", exc_info=True
)
break
except FailedSendingMessageException:
logger.exception("Error while sending buffered messages")
break

async def _get_current_time(self) -> float:
return asyncio.get_event_loop().time()

Expand Down Expand Up @@ -249,30 +264,10 @@ async def send_message(
with use_span(span):
trace_propagator.inject(msg, None, trace_setter)
try:
try:
self._buffer.put(msg)
except MessageBufferClosedError:
# The session is closed and is no longer accepting new messages.
return
async with self._ws_lock:
if not self._ws_wrapper.is_open():
# If the websocket is closed, we should not send the message
# and wait for the retry from the buffer.
return
await send_transport_message(
msg, self._ws_wrapper.ws, self._begin_close_session_countdown
)
except WebsocketClosedException as e:
logger.debug(
"Connection closed while sending message %r, waiting for "
"retry from buffer",
type(e),
exc_info=e,
)
except FailedSendingMessageException:
logger.error(
"Failed sending message, waiting for retry from buffer", exc_info=True
)
self._buffer.put(msg)
except MessageBufferClosedError:
# The session is closed and is no longer accepting new messages.
return

async def close_websocket(
self, ws_wrapper: WebsocketWrapper, should_retry: bool
Expand Down
Loading