@@ -315,7 +315,10 @@ async def websocket_closed_callback() -> None:
315315 "river client get handshake response : %r" , handshake_response
316316 ) # noqa: E501
317317 if not handshake_response .status .ok :
318- if handshake_response .status .code == ERROR_CODE_SESSION_STATE_MISMATCH :
318+ if (
319+ handshake_response .status .code
320+ == ERROR_CODE_SESSION_STATE_MISMATCH
321+ ): # noqa: E501
319322 await self .close ()
320323 raise RiverException (
321324 ERROR_HANDSHAKE ,
@@ -479,8 +482,8 @@ async def close(self) -> None:
479482 self ._state = SessionState .CLOSING
480483
481484 # We need to wake up all tasks waiting for connection to be established
482- assert not self ._connection_condition .locked ()
483- await self ._connection_condition .acquire ()
485+ if not self ._connection_condition .locked ():
486+ await self ._connection_condition .acquire ()
484487 self ._connection_condition .notify_all ()
485488 self ._connection_condition .release ()
486489
@@ -490,6 +493,8 @@ async def close(self) -> None:
490493 # throw exception correctly.
491494 for stream in self ._streams .values ():
492495 stream .close ()
496+ # Before we GC the streams, let's wait for all tasks to be closed gracefully.
497+ await asyncio .gather (* [x .join () for x in self ._streams .values ()])
493498 self ._streams .clear ()
494499
495500 if self ._ws_unwrapped :
@@ -646,7 +651,6 @@ async def block_until_connected() -> None:
646651 async with self ._connection_condition :
647652 await self ._connection_condition .wait ()
648653
649-
650654 self ._task_manager .create_task (
651655 _serve (
652656 block_until_connected = block_until_connected ,
0 commit comments