Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,22 @@ async def serve(self) -> None:
try:
await self._handle_messages_from_ws()
except ConnectionClosed:
if self._should_abort_streams_after_transport_failure():
self._abort_all_streams()
if self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())

await self._begin_close_session_countdown()
logger.debug("ConnectionClosed while serving", exc_info=True)
except FailedSendingMessageException:
# Expected error if the connection is closed.
if self._should_abort_streams_after_transport_failure():
self._abort_all_streams()
logger.debug(
"FailedSendingMessageException while serving", exc_info=True
)
except Exception:
self._abort_all_streams()
logger.exception("caught exception at message iterator")
except ExceptionGroup as eg:
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
Expand Down
5 changes: 5 additions & 0 deletions src/replit_river/server_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def serve(self) -> None:
try:
await self._handle_messages_from_ws(tg)
except ConnectionClosed:
if self._should_abort_streams_after_transport_failure():
self._abort_all_streams()
if self._retry_connection_callback:
self._task_manager.create_task(
self._retry_connection_callback()
Expand All @@ -96,10 +98,13 @@ async def serve(self) -> None:
logger.debug("ConnectionClosed while serving", exc_info=True)
except FailedSendingMessageException:
# Expected error if the connection is closed.
if self._should_abort_streams_after_transport_failure():
self._abort_all_streams()
logger.debug(
"FailedSendingMessageException while serving", exc_info=True
)
except Exception:
self._abort_all_streams()
logger.exception("caught exception at message iterator")
except ExceptionGroup as eg:
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
Expand Down
44 changes: 31 additions & 13 deletions src/replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,35 @@ async def _get_or_create_session(
close_session_callback=self._delete_session,
)
else:
# If the instance id is the same, we reuse the session and assign
# a new websocket to it.
logger.debug(
'Reuse old session with "%s" using new ws: %s',
to_id,
websocket.id,
)
try:
await old_session.replace_with_new_websocket(websocket)
new_session = old_session
except FailedSendingMessageException as e:
raise e
if not await old_session.is_session_open():
logger.info(
'Session "%s" is not active, creating replacement '
"session %s instead of reusing",
to_id,
session_id,
)
new_session = ServerSession(
transport_id,
to_id,
session_id,
websocket,
self._transport_options,
self._handlers,
close_session_callback=self._delete_session,
)
else:
# If the instance id is the same, we reuse the session and assign
# a new websocket to it.
logger.debug(
'Reuse old session with "%s" using new ws: %s',
to_id,
websocket.id,
)
try:
await old_session.replace_with_new_websocket(websocket)
new_session = old_session
except FailedSendingMessageException as e:
raise e

self._sessions[new_session._to_id] = new_session

Expand Down Expand Up @@ -311,5 +328,6 @@ async def _establish_handshake(

async def _delete_session(self, session: Session) -> None:
async with self._session_lock:
if session._to_id in self._sessions:
existing_session = self._sessions.get(session._to_id)
if existing_session is session:
del self._sessions[session._to_id]
15 changes: 12 additions & 3 deletions src/replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ async def close_websocket(
if should_retry and self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())

def _should_abort_streams_after_transport_failure(self) -> bool:
return self._retry_connection_callback is None

def _abort_all_streams(self) -> None:
"""Close all active stream channels, notifying any waiting consumers."""
if not self._streams:
return
for stream in self._streams.values():
stream.close()
self._streams.clear()

async def close(self) -> None:
"""Close the session and all associated streams."""
logger.info(
Expand All @@ -310,9 +321,7 @@ async def close(self) -> None:

# TODO: unexpected_close should close stream differently here to
# throw exception correctly.
for stream in self._streams.values():
stream.close()
self._streams.clear()
self._abort_all_streams()

self._state = SessionState.CLOSED

Expand Down
Loading