@@ -405,12 +405,18 @@ async def _enqueue_message(
405405 self ._process_messages .set ()
406406
407407 async def close (
408- self , reason : Exception | None = None , current_state : SessionState | None = None
408+ self ,
409+ reason : Exception | None = None ,
410+ current_state : SessionState | None = None ,
411+ _wait_for_closed : bool = True ,
409412 ) -> None :
410413 """Close the session and all associated streams."""
411414 if (current_state or self ._state ) in TerminalStates :
412415 start = datetime .now ()
413- while (current_state or self ._state ) != SessionState .CLOSED :
416+ while (
417+ _wait_for_closed
418+ and (current_state or self ._state ) != SessionState .CLOSED
419+ ):
414420 elapsed = (datetime .now () - start ).total_seconds ()
415421 if elapsed >= SESSION_CLOSE_TIMEOUT_SEC :
416422 logger .warning (
@@ -632,7 +638,7 @@ async def block_until_connected() -> None:
632638 get_state = lambda : self ._state ,
633639 get_ws = lambda : self ._ws ,
634640 transition_no_connection = transition_no_connection ,
635- close_session = self .close ,
641+ close_session = lambda err : self .close ( err , _wait_for_closed = False ) ,
636642 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
637643 get_stream = lambda stream_id : self ._streams .get (stream_id ),
638644 enqueue_message = self ._enqueue_message ,
0 commit comments