Skip to content

Commit 248cdf1

Browse files
committed
Fix 2
1 parent 20bfac9 commit 248cdf1

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

src/replit_river/client_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,17 @@ async def serve(self) -> None:
9999
try:
100100
await self._handle_messages_from_ws()
101101
except ConnectionClosed:
102-
self._abort_all_streams()
102+
if self._should_abort_streams_after_transport_failure():
103+
self._abort_all_streams()
103104
if self._retry_connection_callback:
104105
self._task_manager.create_task(self._retry_connection_callback())
105106

106107
await self._begin_close_session_countdown()
107108
logger.debug("ConnectionClosed while serving", exc_info=True)
108109
except FailedSendingMessageException:
109110
# Expected error if the connection is closed.
110-
self._abort_all_streams()
111+
if self._should_abort_streams_after_transport_failure():
112+
self._abort_all_streams()
111113
logger.debug(
112114
"FailedSendingMessageException while serving", exc_info=True
113115
)

src/replit_river/server_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ async def serve(self) -> None:
8787
try:
8888
await self._handle_messages_from_ws(tg)
8989
except ConnectionClosed:
90-
self._abort_all_streams()
90+
if self._should_abort_streams_after_transport_failure():
91+
self._abort_all_streams()
9192
if self._retry_connection_callback:
9293
self._task_manager.create_task(
9394
self._retry_connection_callback()
@@ -97,7 +98,8 @@ async def serve(self) -> None:
9798
logger.debug("ConnectionClosed while serving", exc_info=True)
9899
except FailedSendingMessageException:
99100
# Expected error if the connection is closed.
100-
self._abort_all_streams()
101+
if self._should_abort_streams_after_transport_failure():
102+
self._abort_all_streams()
101103
logger.debug(
102104
"FailedSendingMessageException while serving", exc_info=True
103105
)

src/replit_river/server_transport.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,35 @@ async def _get_or_create_session(
153153
close_session_callback=self._delete_session,
154154
)
155155
else:
156-
# If the instance id is the same, we reuse the session and assign
157-
# a new websocket to it.
158-
logger.debug(
159-
'Reuse old session with "%s" using new ws: %s',
160-
to_id,
161-
websocket.id,
162-
)
163-
try:
164-
await old_session.replace_with_new_websocket(websocket)
165-
new_session = old_session
166-
except FailedSendingMessageException as e:
167-
raise e
156+
if not await old_session.is_session_open():
157+
logger.info(
158+
'Session "%s" is not active, creating replacement '
159+
"session %s instead of reusing",
160+
to_id,
161+
session_id,
162+
)
163+
new_session = ServerSession(
164+
transport_id,
165+
to_id,
166+
session_id,
167+
websocket,
168+
self._transport_options,
169+
self._handlers,
170+
close_session_callback=self._delete_session,
171+
)
172+
else:
173+
# If the instance id is the same, we reuse the session and assign
174+
# a new websocket to it.
175+
logger.debug(
176+
'Reuse old session with "%s" using new ws: %s',
177+
to_id,
178+
websocket.id,
179+
)
180+
try:
181+
await old_session.replace_with_new_websocket(websocket)
182+
new_session = old_session
183+
except FailedSendingMessageException as e:
184+
raise e
168185

169186
self._sessions[new_session._to_id] = new_session
170187

@@ -311,5 +328,6 @@ async def _establish_handshake(
311328

312329
async def _delete_session(self, session: Session) -> None:
313330
async with self._session_lock:
314-
if session._to_id in self._sessions:
331+
existing_session = self._sessions.get(session._to_id)
332+
if existing_session is session:
315333
del self._sessions[session._to_id]

src/replit_river/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ async def close_websocket(
286286
if should_retry and self._retry_connection_callback:
287287
self._task_manager.create_task(self._retry_connection_callback())
288288

289+
def _should_abort_streams_after_transport_failure(self) -> bool:
290+
return self._retry_connection_callback is None
291+
289292
def _abort_all_streams(self) -> None:
290293
"""Close all active stream channels, notifying any waiting consumers."""
291294
if not self._streams:

0 commit comments

Comments
 (0)