3939 RiverError ,
4040 RiverException ,
4141 RiverServiceException ,
42+ SessionClosedRiverServiceException ,
4243 StreamClosedRiverServiceException ,
4344 exception_from_message ,
4445)
@@ -154,8 +155,10 @@ def __init__(
154155
155156 # message state
156157 self ._message_enqueued = asyncio .Semaphore ()
157- self ._space_available_cond = asyncio .Condition ()
158- self ._queue_full_lock = asyncio .Lock ()
158+ self ._space_available = asyncio .Event ()
159+ # Ensure we initialize the above Event to "set" to avoid being blocked from
160+ # the beginning.
161+ self ._space_available .set ()
159162
160163 # stream for tasks
161164 self ._streams : dict [str , Channel [Any ]] = {}
@@ -337,22 +340,24 @@ async def _send_message(
337340 with use_span (span ):
338341 trace_propagator .inject (msg , None , trace_setter )
339342
340- # As we prepare to push onto the buffer, if the buffer is full, we lock.
341- # This lock will be released by the buffered_message_sender task, so it's
342- # important that we don't release it here.
343- #
344- # The reason for this is that in Python, asyncio.Lock is "fair", first
345- # come, first served.
346- #
347- # If somebody else is already waiting or we've filled the buffer, we
348- # should get in line.
349- if (
350- self ._queue_full_lock .locked ()
351- or len (self ._send_buffer ) >= self ._transport_options .buffer_size
352- ):
353- logger .debug ("_send_message: queue full, waiting" )
354- await self ._queue_full_lock .acquire ()
343+ # Ensure the buffer isn't full before we enqueue
344+ await self ._space_available .wait ()
345+
346+ # Before we append, do an important check
347+ if self ._state in TerminalStates :
348+ # session is closing / closed, raise
349+ raise SessionClosedRiverServiceException (
350+ "river session is closed, dropping message" ,
351+ service_name ,
352+ procedure_name ,
353+ )
354+
355355 self ._send_buffer .append (msg )
356+
357+ # If the buffer is now full, reset the block
358+ if len (self ._send_buffer ) >= self ._transport_options .buffer_size :
359+ self ._space_available .clear ()
360+
356361 # Wake up buffered_message_sender
357362 self ._message_enqueued .release ()
358363 self .seq += 1
@@ -368,7 +373,10 @@ async def close(self) -> None:
368373 self ._state = SessionState .CLOSING
369374
370375 # We need to wake up all tasks waiting for connection to be established
371- self ._wait_for_connected .clear ()
376+ self ._wait_for_connected .set ()
377+
378+ # We also need to wake up consumers waiting to enqueue messages
379+ self ._space_available .set ()
372380
373381 await self ._task_manager .cancel_all_tasks ()
374382
@@ -399,8 +407,7 @@ def commit(msg: TransportMessage) -> None:
399407 self ._ack_buffer .append (pending )
400408
401409 # On commit, release pending writers waiting for more buffer space
402- if self ._queue_full_lock .locked ():
403- self ._queue_full_lock .release ()
410+ self ._space_available .set ()
404411
405412 def get_next_pending () -> TransportMessage | None :
406413 if self ._send_buffer :
@@ -1157,7 +1164,9 @@ async def _serve(
11571164 while our_task and not our_task .cancelling () and not our_task .cancelled ():
11581165 logger .debug (f"_serve loop count={ idx } " )
11591166 idx += 1
1160- while (ws := get_ws ()) is None or (state := get_state ()) in ConnectingStates :
1167+ while (ws := get_ws ()) is None or (
1168+ state := get_state ()
1169+ ) in ConnectingStates :
11611170 logger .debug ("_handle_messages_from_ws spinning while connecting" )
11621171 await block_until_connected ()
11631172
0 commit comments