@@ -252,28 +252,6 @@ def get_next_sent_seq() -> int:
252252 return self ._send_buffer [0 ].seq
253253 return self .seq
254254
255- def close_session (reason : Exception | None ) -> None :
256- # If we're already closing, just let whoever's currently doing it handle it.
257- if self ._state in TerminalStates :
258- return
259-
260- # Avoid closing twice
261- if self ._terminating_task is None :
262- current_state = self ._state
263- self ._state = SessionState .CLOSING
264-
265- # We can't just call self.close() directly because
266- # we're inside a thread that will eventually be awaited
267- # during the cleanup procedure.
268-
269- self ._terminating_task = asyncio .create_task (
270- self .close (
271- reason ,
272- current_state = current_state ,
273- _wait_for_closed = False ,
274- ),
275- )
276-
277255 def transition_connecting (ws : ClientConnection ) -> None :
278256 if self ._state in TerminalStates :
279257 return
@@ -328,7 +306,7 @@ def unbind_connecting_task() -> None:
328306 close_ws_in_background = close_ws_in_background ,
329307 transition_connected = transition_connected ,
330308 unbind_connecting_task = unbind_connecting_task ,
331- close_session = close_session ,
309+ close_session = self . _close_internal_nowait ,
332310 )
333311 )
334312
@@ -413,14 +391,9 @@ async def _enqueue_message(
413391 async def close (
414392 self ,
415393 reason : Exception | None = None ,
416- current_state : SessionState | None = None ,
417- _wait_for_closed : bool = True ,
418394 ) -> None :
419395 """Close the session and all associated streams."""
420396 if self ._closing_waiter :
421- # Break early for internal callers
422- if not _wait_for_closed :
423- return
424397 try :
425398 logger .debug ("Session already closing, waiting..." )
426399 async with asyncio .timeout (SESSION_CLOSE_TIMEOUT_SEC ):
@@ -431,79 +404,112 @@ async def close(
431404 "seconds to close, leaking" ,
432405 )
433406 return
434- logger .info (
435- f"{ self .session_id } closing session to { self ._server_id } , ws: { self ._ws } "
436- )
437- self ._state = SessionState .CLOSING
438- self ._closing_waiter = asyncio .Event ()
407+ await self ._close_internal (reason )
439408
440- # We're closing, so we need to wake up...
441- # ... tasks waiting for connection to be established
442- self ._wait_for_connected .set ()
443- # ... consumers waiting to enqueue messages
444- self ._space_available .set ()
445- # ... message processor so it can exit cleanly
446- self ._process_messages .set ()
409+ def _close_internal_nowait (self , reason : Exception | None = None ) -> None :
410+ """
411+ When calling close() from asyncio Tasks, we must not block.
412+
413+ This function does so, deferring to the underlying infrastructure for
414+ creating self._terminating_task.
415+ """
416+ self ._close_internal (reason )
417+
418+ def _close_internal (self , reason : Exception | None = None ) -> asyncio .Task [None ]:
419+ """
420+ Internal close method. Subsequent calls past the first do not block.
421+
422+ This is intended to be the primary driver of a session being torn down
423+ and returned to its initial state.
424+
425+ NB: This function is intended to be the sole lifecycle manager of
426+ self._terminating_task. Waiting on the completion of that task is optional,
427+ but the population of that property is critical.
428+
429+ NB: We must not await the task returned from this function from chained tasks
430+ inside this session, otherwise we will create a thread loop.
431+ """
432+
433+ async def do_close () -> None :
434+ logger .info (
435+ f"{ self .session_id } closing session to { self ._server_id } , "
436+ f"ws: { self ._ws } "
437+ )
438+ self ._state = SessionState .CLOSING
439+ self ._closing_waiter = asyncio .Event ()
447440
448- # Wait to permit the waiting tasks to shut down gracefully
449- await asyncio .sleep (0.25 )
441+ # We're closing, so we need to wake up...
442+ # ... tasks waiting for connection to be established
443+ self ._wait_for_connected .set ()
444+ # ... consumers waiting to enqueue messages
445+ self ._space_available .set ()
446+ # ... message processor so it can exit cleanly
447+ self ._process_messages .set ()
448+
449+ # Wait to permit the waiting tasks to shut down gracefully
450+ await asyncio .sleep (0.25 )
450451
451- await self ._task_manager .cancel_all_tasks ()
452+ await self ._task_manager .cancel_all_tasks ()
452453
453- for stream_meta in self ._streams .values ():
454- stream_meta ["output" ].close ()
455- # Wake up backpressured writers
454+ for stream_meta in self ._streams .values ():
455+ stream_meta ["output" ].close ()
456+ # Wake up backpressured writers
457+ try :
458+ stream_meta ["error_channel" ].put_nowait (
459+ reason
460+ or SessionClosedRiverServiceException (
461+ "river session is closed" ,
462+ )
463+ )
464+ except ChannelFull :
465+ logger .exception (
466+ "Unable to tell the caller that the session is going away" ,
467+ )
468+ stream_meta ["release_backpressured_waiter" ]()
469+ # Before we GC the streams, let's wait for all tasks to be closed gracefully
456470 try :
457- stream_meta ["error_channel" ].put_nowait (
458- reason
459- or SessionClosedRiverServiceException (
460- "river session is closed" ,
471+ async with asyncio .timeout (
472+ self ._transport_options .shutdown_all_streams_timeout_ms
473+ ):
474+ # Block for backpressure and emission errors from the ws
475+ await asyncio .gather (
476+ * [
477+ stream_meta ["output" ].join ()
478+ for stream_meta in self ._streams .values ()
479+ ]
461480 )
462- )
463- except ChannelFull :
481+ except asyncio .TimeoutError :
482+ spans : list [Span ] = [
483+ stream_meta ["span" ]
484+ for stream_meta in self ._streams .values ()
485+ if not stream_meta ["output" ].closed ()
486+ ]
487+ span_ids = [span .get_span_context ().span_id for span in spans ]
464488 logger .exception (
465- "Unable to tell the caller that the session is going away" ,
489+ "Timeout waiting for output streams to finallize" ,
490+ extra = {"span_ids" : span_ids },
466491 )
467- stream_meta ["release_backpressured_waiter" ]()
468- # Before we GC the streams, let's wait for all tasks to be closed gracefully.
469- try :
470- async with asyncio .timeout (
471- self ._transport_options .shutdown_all_streams_timeout_ms
472- ):
473- # Block for backpressure and emission errors from the ws
474- await asyncio .gather (
475- * [
476- stream_meta ["output" ].join ()
477- for stream_meta in self ._streams .values ()
478- ]
479- )
480- except asyncio .TimeoutError :
481- spans : list [Span ] = [
482- stream_meta ["span" ]
483- for stream_meta in self ._streams .values ()
484- if not stream_meta ["output" ].closed ()
485- ]
486- span_ids = [span .get_span_context ().span_id for span in spans ]
487- logger .exception (
488- "Timeout waiting for output streams to finallize" ,
489- extra = {"span_ids" : span_ids },
490- )
491- self ._streams .clear ()
492+ self ._streams .clear ()
492493
493- if self ._ws :
494- # The Session isn't guaranteed to live much longer than this close()
495- # invocation, so let's await this close to avoid dropping the socket.
496- await self ._ws .close ()
494+ if self ._ws :
495+ # The Session isn't guaranteed to live much longer than this close()
496+ # invocation, so let's await this close to avoid dropping the socket.
497+ await self ._ws .close ()
497498
498- self ._state = SessionState .CLOSED
499+ self ._state = SessionState .CLOSED
499500
500- # Clear the session in transports
501- # This will get us GC'd, so this should be the last thing.
502- self ._close_session_callback (self )
501+ # Clear the session in transports
502+ # This will get us GC'd, so this should be the last thing.
503+ self ._close_session_callback (self )
503504
504- # Release waiters, then release the event
505- self ._closing_waiter .set ()
506- self ._closing_waiter = None
505+ # Release waiters, then release the event
506+ self ._closing_waiter .set ()
507+ self ._closing_waiter = None
508+
509+ if self ._terminating_task :
510+ return self ._terminating_task
511+
512+ return asyncio .create_task (do_close ())
507513
508514 def _start_buffered_message_sender (
509515 self ,
@@ -646,7 +652,7 @@ async def block_until_connected() -> None:
646652 get_state = lambda : self ._state ,
647653 get_ws = lambda : self ._ws ,
648654 transition_no_connection = transition_no_connection ,
649- close_session = lambda err : self .close ( err , _wait_for_closed = False ) ,
655+ close_session = self ._close_internal_nowait ,
650656 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
651657 get_stream = lambda stream_id : self ._streams .get (stream_id ),
652658 enqueue_message = self ._enqueue_message ,
@@ -1137,8 +1143,11 @@ async def websocket_closed_callback() -> None:
11371143
11381144 try :
11391145 data = await ws .recv (decode = False )
1140- except ConnectionClosedOK as e :
1141- close_session (e )
1146+ except ConnectionClosedOK :
1147+ # In the case of a normal connection closure, we defer to
1148+ # the outer loop to determine next steps.
1149+ # A call to close(...) should set the SessionState to a terminal one,
1150+ # otherwise we should try again.
11421151 continue
11431152 except ConnectionClosed as e :
11441153 logger .debug (
@@ -1226,7 +1235,7 @@ async def _recv_from_ws(
12261235 get_state : Callable [[], SessionState ],
12271236 get_ws : Callable [[], ClientConnection | None ],
12281237 transition_no_connection : Callable [[], Awaitable [None ]],
1229- close_session : Callable [[Exception | None ], Awaitable [ None ] ],
1238+ close_session : Callable [[Exception | None ], None ],
12301239 assert_incoming_seq_bookkeeping : Callable [
12311240 [str , int , int ], Literal [True ] | _IgnoreMessage
12321241 ],
@@ -1361,7 +1370,7 @@ async def _recv_from_ws(
13611370 stream_meta ["output" ].close ()
13621371 except OutOfOrderMessageException :
13631372 logger .exception ("Out of order message, closing connection" )
1364- await close_session (
1373+ close_session (
13651374 SessionClosedRiverServiceException (
13661375 "Out of order message, closing connection"
13671376 )
@@ -1371,7 +1380,7 @@ async def _recv_from_ws(
13711380 logger .exception (
13721381 "Got invalid transport message, closing session" ,
13731382 )
1374- await close_session (
1383+ close_session (
13751384 SessionClosedRiverServiceException (
13761385 "Out of order message, closing connection"
13771386 )
0 commit comments