@@ -121,6 +121,9 @@ class Session:
121121 ack : int # Most recently acknowledged seq
122122 seq : int # Last sent sequence number
123123
124+ # Terminating
125+ _terminating_task : asyncio .Task [None ]
126+
124127 def __init__ (
125128 self ,
126129 transport_id : str ,
@@ -188,13 +191,20 @@ async def ensure_connected[HandshakeMetadata](
188191 if self .is_connected ():
189192 return
190193
194+ def do_close () -> None :
195+ # We can't just call self.close() directly because
196+ # we're inside a thread that will eventually be awaited
197+ # during the cleanup procedure.
198+ self ._terminating_task = asyncio .create_task (self .close ())
199+
191200 if not self ._connecting_task :
192201 self ._connecting_task = asyncio .create_task (
193202 self ._do_ensure_connected (
194203 client_id ,
195204 rate_limiter ,
196205 uri_and_metadata_factory ,
197206 protocol_version ,
207+ do_close ,
198208 )
199209 )
200210
@@ -208,6 +218,7 @@ async def _do_ensure_connected[HandshakeMetadata](
208218 [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
209219 ], # noqa: E501
210220 protocol_version : str ,
221+ do_close : Callable [[], None ],
211222 ) -> Literal [True ]:
212223 max_retry = self ._transport_options .connection_retry_options .max_retry
213224 logger .info ("Attempting to establish new ws connection" )
@@ -329,7 +340,8 @@ async def websocket_closed_callback() -> None:
329340 handshake_response .status .code
330341 == ERROR_CODE_SESSION_STATE_MISMATCH
331342 ):
332- await self .close ()
343+ do_close ()
344+
333345 raise RiverException (
334346 ERROR_HANDSHAKE ,
335347 f"Handshake failed with code { handshake_response .status .code } : {
@@ -553,14 +565,20 @@ async def block_until_connected() -> None:
553565 )
554566
555567 def _start_close_session_checker (self ) -> None :
568+ def do_close () -> None :
569+ # We can't just call self.close() directly because
570+ # we're inside a thread that will eventually be awaited
571+ # during the cleanup procedure.
572+ self ._terminating_task = asyncio .create_task (self .close ())
573+
556574 self ._task_manager .create_task (
557575 _check_to_close_session (
558576 self ._transport_id ,
559577 self ._transport_options .close_session_check_interval_ms ,
560578 lambda : self ._state ,
561579 self ._get_current_time ,
562580 lambda : self ._close_session_after_time_secs ,
563- self . close ,
581+ do_close = do_close ,
564582 )
565583 )
566584
@@ -986,7 +1004,7 @@ async def _check_to_close_session(
9861004 get_state : Callable [[], SessionState ],
9871005 get_current_time : Callable [[], Awaitable [float ]],
9881006 get_close_session_after_time_secs : Callable [[], float | None ],
989- do_close : Callable [[], Awaitable [ None ] ],
1007+ do_close : Callable [[], None ],
9901008) -> None :
9911009 our_task = asyncio .current_task ()
9921010 while our_task and not our_task .cancelling () and not our_task .cancelled ():
@@ -1003,8 +1021,8 @@ async def _check_to_close_session(
10031021 continue
10041022 if current_time > close_session_after_time_secs :
10051023 logger .info ("Grace period ended for %s, closing session" , transport_id )
1006- await do_close ()
1007- return
1024+ do_close ()
1025+ our_task . cancel ()
10081026
10091027
10101028async def _buffered_message_sender (
0 commit comments