diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index ed498847..d703c87b 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -99,6 +99,8 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws() except ConnectionClosed: + if self._should_abort_streams_after_transport_failure(): + await self.close() if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -106,10 +108,13 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. + if self._should_abort_streams_after_transport_failure(): + await self.close() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) except Exception: + await self.close() logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 465a6672..b4240b1a 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -286,6 +286,17 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) + def _should_abort_streams_after_transport_failure(self) -> bool: + return not self._transport_options.transparent_reconnect + + def _abort_all_streams(self) -> None: + """Close all active stream channels, notifying any waiting consumers.""" + if not self._streams: + return + for stream in self._streams.values(): + stream.close() + self._streams.clear() + async def close(self) -> None: """Close the session and all associated streams.""" logger.info( @@ -310,9 +321,7 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for stream in self._streams.values(): - stream.close() - self._streams.clear() + self._abort_all_streams() self._state = SessionState.CLOSED