Skip to content

Commit 20bfac9

Browse files
committed
Abort streams
1 parent dcec28b commit 20bfac9

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/replit_river/client_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,20 @@ async def serve(self) -> None:
9999
try:
100100
await self._handle_messages_from_ws()
101101
except ConnectionClosed:
102+
self._abort_all_streams()
102103
if self._retry_connection_callback:
103104
self._task_manager.create_task(self._retry_connection_callback())
104105

105106
await self._begin_close_session_countdown()
106107
logger.debug("ConnectionClosed while serving", exc_info=True)
107108
except FailedSendingMessageException:
108109
# Expected error if the connection is closed.
110+
self._abort_all_streams()
109111
logger.debug(
110112
"FailedSendingMessageException while serving", exc_info=True
111113
)
112114
except Exception:
115+
self._abort_all_streams()
113116
logger.exception("caught exception at message iterator")
114117
except ExceptionGroup as eg:
115118
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))

src/replit_river/server_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def serve(self) -> None:
8787
try:
8888
await self._handle_messages_from_ws(tg)
8989
except ConnectionClosed:
90+
self._abort_all_streams()
9091
if self._retry_connection_callback:
9192
self._task_manager.create_task(
9293
self._retry_connection_callback()
@@ -96,10 +97,12 @@ async def serve(self) -> None:
9697
logger.debug("ConnectionClosed while serving", exc_info=True)
9798
except FailedSendingMessageException:
9899
# Expected error if the connection is closed.
100+
self._abort_all_streams()
99101
logger.debug(
100102
"FailedSendingMessageException while serving", exc_info=True
101103
)
102104
except Exception:
105+
self._abort_all_streams()
103106
logger.exception("caught exception at message iterator")
104107
except ExceptionGroup as eg:
105108
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))

src/replit_river/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,14 @@ async def close_websocket(
286286
if should_retry and self._retry_connection_callback:
287287
self._task_manager.create_task(self._retry_connection_callback())
288288

289+
def _abort_all_streams(self) -> None:
290+
"""Close all active stream channels, notifying any waiting consumers."""
291+
if not self._streams:
292+
return
293+
for stream in self._streams.values():
294+
stream.close()
295+
self._streams.clear()
296+
289297
async def close(self) -> None:
290298
"""Close the session and all associated streams."""
291299
logger.info(
@@ -310,9 +318,7 @@ async def close(self) -> None:
310318

311319
# TODO: unexpected_close should close stream differently here to
312320
# throw exception correctly.
313-
for stream in self._streams.values():
314-
stream.close()
315-
self._streams.clear()
321+
self._abort_all_streams()
316322

317323
self._state = SessionState.CLOSED
318324

0 commit comments

Comments
 (0)