Skip to content

Commit bb52361

Browse files
Avoid deadlocking client if streams don't clean up after themselves
1 parent 8208d2b commit bb52361

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

src/replit_river/transport_options.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class TransportOptions(BaseModel):
2727
connection_retry_options: ConnectionRetryOptions = ConnectionRetryOptions()
2828
buffer_size: int = 1_000
2929
transparent_reconnect: bool = True
30+
shutdown_all_streams_timeout_ms: float = 10_000
3031

3132
def websocket_disconnect_grace_ms(self) -> float:
3233
return self.heartbeat_ms * self.heartbeats_until_dead
@@ -39,11 +40,16 @@ def create_from_env(cls) -> "TransportOptions":
3940
)
4041
heartbeat_ms = float(os.getenv("HEARTBEAT_MS", 2_000))
4142
heartbeats_to_dead = int(os.getenv("HEARTBEATS_UNTIL_DEAD", 2))
43+
shutdown_all_streams_timeout_ms = float(
44+
os.getenv("SHUTDOWN_STREAMS_TIMEOUT_MS", 10_000)
45+
)
46+
4247
return TransportOptions(
4348
handshake_timeout_ms=handshake_timeout_ms,
4449
session_disconnect_grace_ms=session_disconnect_grace_ms,
4550
heartbeat_ms=heartbeat_ms,
4651
heartbeats_until_dead=heartbeats_to_dead,
52+
shutdown_all_streams_timeout_ms=shutdown_all_streams_timeout_ms,
4753
)
4854

4955

src/replit_river/v2/session.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ def close_session(reason: Exception | None) -> None:
262262
# during the cleanup procedure.
263263

264264
self._terminating_task = asyncio.create_task(
265-
self.close(reason, current_state=current_state),
266-
)
265+
self.close(reason, current_state=current_state),
266+
)
267267

268268
def transition_connecting() -> None:
269269
if self._state in TerminalStates:
@@ -396,7 +396,9 @@ async def _enqueue_message(
396396
# Wake up buffered_message_sender
397397
self._process_messages.set()
398398

399-
async def close(self, reason: Exception | None = None, current_state: SessionState | None = None ) -> None:
399+
async def close(
400+
self, reason: Exception | None = None, current_state: SessionState | None = None
401+
) -> None:
400402
"""Close the session and all associated streams."""
401403
logger.info(
402404
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}"
@@ -435,9 +437,28 @@ async def close(self, reason: Exception | None = None, current_state: SessionSta
435437
)
436438
stream_meta["release_backpressured_waiter"]()
437439
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
438-
await asyncio.gather(
439-
*[stream_meta["output"].join() for stream_meta in self._streams.values()]
440-
)
440+
try:
441+
async with asyncio.timeout(
442+
self._transport_options.shutdown_all_streams_timeout_ms
443+
):
444+
# Block for backpressure and emission errors from the ws
445+
await asyncio.gather(
446+
*[
447+
stream_meta["output"].join()
448+
for stream_meta in self._streams.values()
449+
]
450+
)
451+
except asyncio.TimeoutError:
452+
spans: list[Span] = [
453+
stream_meta["span"]
454+
for stream_meta in self._streams.values()
455+
if not stream_meta["output"].closed()
456+
]
457+
span_ids = [span.get_span_context().span_id for span in spans]
458+
logger.exception(
459+
"Timeout waiting for output streams to finallize",
460+
extra={"span_ids": span_ids},
461+
)
441462
self._streams.clear()
442463

443464
if self._ws:

0 commit comments

Comments
 (0)