Skip to content

Commit 1d2eede

Browse files
More lifecycle tweaks during shutdown so we exit cleanly.
1 parent b8b97b0 commit 1d2eede

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/replit_river/v2/session.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,10 @@ async def websocket_closed_callback() -> None:
315315
"river client get handshake response : %r", handshake_response
316316
) # noqa: E501
317317
if not handshake_response.status.ok:
318-
if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH:
318+
if (
319+
handshake_response.status.code
320+
== ERROR_CODE_SESSION_STATE_MISMATCH
321+
): # noqa: E501
319322
await self.close()
320323
raise RiverException(
321324
ERROR_HANDSHAKE,
@@ -479,8 +482,8 @@ async def close(self) -> None:
479482
self._state = SessionState.CLOSING
480483

481484
# We need to wake up all tasks waiting for connection to be established
482-
assert not self._connection_condition.locked()
483-
await self._connection_condition.acquire()
485+
if not self._connection_condition.locked():
486+
await self._connection_condition.acquire()
484487
self._connection_condition.notify_all()
485488
self._connection_condition.release()
486489

@@ -490,6 +493,8 @@ async def close(self) -> None:
490493
# throw exception correctly.
491494
for stream in self._streams.values():
492495
stream.close()
496+
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
497+
await asyncio.gather(*[x.join() for x in self._streams.values()])
493498
self._streams.clear()
494499

495500
if self._ws_unwrapped:
@@ -646,7 +651,6 @@ async def block_until_connected() -> None:
646651
async with self._connection_condition:
647652
await self._connection_condition.wait()
648653

649-
650654
self._task_manager.create_task(
651655
_serve(
652656
block_until_connected=block_until_connected,

0 commit comments

Comments
 (0)