@@ -265,7 +265,11 @@ def close_session(reason: Exception | None) -> None:
265265 # during the cleanup procedure.
266266
267267 self ._terminating_task = asyncio .create_task (
268- self .close (reason , current_state = current_state ),
268+ self .close (
269+ reason ,
270+ current_state = current_state ,
271+ _wait_for_closed = False ,
272+ ),
269273 )
270274
271275 def transition_connecting (ws : ClientConnection ) -> None :
@@ -405,12 +409,18 @@ async def _enqueue_message(
405409 self ._process_messages .set ()
406410
407411 async def close (
408- self , reason : Exception | None = None , current_state : SessionState | None = None
412+ self ,
413+ reason : Exception | None = None ,
414+ current_state : SessionState | None = None ,
415+ _wait_for_closed : bool = True ,
409416 ) -> None :
410417 """Close the session and all associated streams."""
411418 if (current_state or self ._state ) in TerminalStates :
412419 start = datetime .now ()
413- while (current_state or self ._state ) != SessionState .CLOSED :
420+ while (
421+ _wait_for_closed
422+ and (current_state or self ._state ) != SessionState .CLOSED
423+ ):
414424 elapsed = (datetime .now () - start ).total_seconds ()
415425 if elapsed >= SESSION_CLOSE_TIMEOUT_SEC :
416426 logger .warning (
@@ -632,7 +642,7 @@ async def block_until_connected() -> None:
632642 get_state = lambda : self ._state ,
633643 get_ws = lambda : self ._ws ,
634644 transition_no_connection = transition_no_connection ,
635- close_session = self .close ,
645+ close_session = lambda err : self .close ( err , _wait_for_closed = False ) ,
636646 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
637647 get_stream = lambda stream_id : self ._streams .get (stream_id ),
638648 enqueue_message = self ._enqueue_message ,
0 commit comments