@@ -114,6 +114,10 @@ class Session:
114114 _heartbeat_misses : int
115115 _retry_connection_callback : RetryConnectionCallback | None
116116
117+ # message state
118+ _process_messages : asyncio .Event
119+ _space_available : asyncio .Event
120+
117121 # stream for tasks
118122 _streams : dict [str , Channel [Any ]]
119123
@@ -154,7 +158,7 @@ def __init__(
154158 self ._retry_connection_callback = retry_connection_callback
155159
156160 # message state
157- self ._message_enqueued = asyncio .Semaphore ()
161+ self ._process_messages = asyncio .Event ()
158162 self ._space_available = asyncio .Event ()
159163 # Ensure we initialize the above Event to "set" to avoid being blocked from
160164 # the beginning.
@@ -359,7 +363,7 @@ async def _send_message(
359363 self ._space_available .clear ()
360364
361365 # Wake up buffered_message_sender
362- self ._message_enqueued . release ()
366+ self ._process_messages . set ()
363367 self .seq += 1
364368
365369 async def close (self ) -> None :
@@ -372,11 +376,13 @@ async def close(self) -> None:
372376 return
373377 self ._state = SessionState .CLOSING
374378
375- # We need to wake up all tasks waiting for connection to be established
379+ # We're closing, so we need to wake up...
380+ # ... tasks waiting for connection to be established
376381 self ._wait_for_connected .set ()
377-
378- # We also need to wake up consumers waiting to enqueue messages
382+ # ... consumers waiting to enqueue messages
379383 self ._space_available .set ()
384+ # ... message processor so it can exit cleanly
385+ self ._process_messages .set ()
380386
381387 await self ._task_manager .cancel_all_tasks ()
382388
@@ -406,8 +412,12 @@ def commit(msg: TransportMessage) -> None:
406412 logger .error ("Out of sequence error" )
407413 self ._ack_buffer .append (pending )
408414
409- # On commit, release pending writers waiting for more buffer space
415+ # On commit...
416+ # ... release pending writers waiting for more buffer space
410417 self ._space_available .set ()
418+ # ... tell the message sender to back off if there are no pending messages
419+ if not self ._send_buffer :
420+ self ._process_messages .clear ()
411421
412422 def get_next_pending () -> TransportMessage | None :
413423 if self ._send_buffer :
@@ -422,10 +432,13 @@ def get_ws() -> ClientConnection | None:
422432 async def block_until_connected () -> None :
423433 await self ._wait_for_connected .wait ()
424434
435+ async def block_until_message_available () -> None :
436+ await self ._process_messages .wait ()
437+
425438 self ._task_manager .create_task (
426439 _buffered_message_sender (
427440 block_until_connected = block_until_connected ,
428- message_enqueued = self . _message_enqueued ,
441+ block_until_message_available = block_until_message_available ,
429442 get_ws = get_ws ,
430443 websocket_closed_callback = self ._begin_close_session_countdown ,
431444 get_next_pending = get_next_pending ,
@@ -865,7 +878,7 @@ async def send_close_stream(
865878
866879async def _buffered_message_sender (
867880 block_until_connected : Callable [[], Awaitable [None ]],
868- message_enqueued : asyncio . Semaphore ,
881+ block_until_message_available : Callable [[], Awaitable [ None ]] ,
869882 get_ws : Callable [[], ClientConnection | None ],
870883 websocket_closed_callback : Callable [[], Coroutine [Any , Any , None ]],
871884 get_next_pending : Callable [[], TransportMessage | None ],
@@ -874,18 +887,19 @@ async def _buffered_message_sender(
874887) -> None :
875888 our_task = asyncio .current_task ()
876889 while our_task and not our_task .cancelling () and not our_task .cancelled ():
877- await message_enqueued .acquire ()
890+ await block_until_message_available ()
891+
892+ if get_state () in TerminalStates :
893+ logger .debug ("buffered_message_sender: closing" )
894+ return
895+
878896 while (ws := get_ws ()) is None :
879897 # Block until we have a handle
880898 logger .debug (
881899 "buffered_message_sender: Waiting until ws is connected" ,
882900 )
883901 await block_until_connected ()
884902
885- if get_state () in TerminalStates :
886- logger .debug ("We're going away!" )
887- return
888-
889903 if not ws :
890904 logger .debug ("ws is not connected, loop" )
891905 continue
@@ -906,18 +920,15 @@ async def _buffered_message_sender(
906920 type (e ),
907921 exc_info = e ,
908922 )
909- message_enqueued .release ()
910923 break
911924 except FailedSendingMessageException :
912925 logger .error (
913926 "Failed sending message, waiting for retry from buffer" ,
914927 exc_info = True ,
915928 )
916- message_enqueued .release ()
917929 break
918930 except Exception :
919931 logger .exception ("Error attempting to send buffered messages" )
920- message_enqueued .release ()
921932 break
922933
923934
0 commit comments