Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,22 @@ async def serve(self) -> None:
try:
await self._handle_messages_from_ws()
except ConnectionClosed:
if self._should_abort_streams_after_transport_failure():
await self.close()
if self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())

await self._begin_close_session_countdown()
logger.debug("ConnectionClosed while serving", exc_info=True)
except FailedSendingMessageException:
# Expected error if the connection is closed.
if self._should_abort_streams_after_transport_failure():
await self.close()
logger.debug(
"FailedSendingMessageException while serving", exc_info=True
)
except Exception:
await self.close()
logger.exception("caught exception at message iterator")
except ExceptionGroup as eg:
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
Expand Down
15 changes: 12 additions & 3 deletions src/replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ async def close_websocket(
if should_retry and self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())

def _should_abort_streams_after_transport_failure(self) -> bool:
return not self._transport_options.transparent_reconnect

def _abort_all_streams(self) -> None:
"""Close all active stream channels, notifying any waiting consumers."""
if not self._streams:
return
for stream in self._streams.values():
stream.close()
self._streams.clear()

async def close(self) -> None:
"""Close the session and all associated streams."""
logger.info(
Expand All @@ -310,9 +321,7 @@ async def close(self) -> None:

# TODO: unexpected_close should close stream differently here to
# throw exception correctly.
for stream in self._streams.values():
stream.close()
self._streams.clear()
self._abort_all_streams()

self._state = SessionState.CLOSED

Expand Down
Loading