@@ -100,6 +100,7 @@ class Session:
100100 _close_session_callback : CloseSessionCallback
101101 _close_session_after_time_secs : float | None
102102 _connecting_task : asyncio .Task [Literal [True ]] | None
103+ _connection_condition : asyncio .Condition
103104
104105 # ws state
105106 _ws_unwrapped : ClientConnection | None
@@ -135,6 +136,7 @@ def __init__(
135136 self ._close_session_callback = close_session_callback
136137 self ._close_session_after_time_secs : float | None = None
137138 self ._connecting_task = None
139+ self ._connection_condition = asyncio .Condition ()
138140
139141 # ws state
140142 self ._ws_unwrapped = None
@@ -162,11 +164,13 @@ async def do_close_websocket() -> None:
162164 self ._state ,
163165 self ._ws_unwrapped ,
164166 )
165- self ._state = SessionState .CLOSING
166167 if self ._ws_unwrapped :
167168 self ._task_manager .create_task (self ._ws_unwrapped .close ())
168169 if self ._retry_connection_callback :
169170 self ._task_manager .create_task (self ._retry_connection_callback ())
171+ self ._ws_unwrapped = None
172+ else :
173+ self ._state = SessionState .CLOSING
170174 await self ._begin_close_session_countdown ()
171175
172176 def increment_and_get_heartbeat_misses () -> int :
@@ -211,14 +215,18 @@ def get_next_pending() -> TransportMessage | None:
211215 return self ._send_buffer [0 ]
212216 return None
213217
218+ # TODO: Just return _ws_unwrapped once we are no longer using the legacy client
219+ def get_ws () -> WebSocketCommonProtocol | ClientConnection | None :
220+ logger .debug ("get_ws: %r %r" , self .is_connected (), self ._ws_unwrapped )
221+ if self .is_connected ():
222+ return self ._ws_unwrapped
223+ return None
224+
214225 self ._task_manager .create_task (
215226 buffered_message_sender (
227+ self ._connection_condition ,
216228 self ._message_enqueued ,
217- get_ws = lambda : (
218- cast (WebSocketCommonProtocol | ClientConnection , self ._ws_unwrapped )
219- if self .is_connected ()
220- else None
221- ),
229+ get_ws = get_ws ,
222230 websocket_closed_callback = self ._begin_close_session_countdown ,
223231 get_next_pending = get_next_pending ,
224232 commit = commit ,
@@ -242,6 +250,7 @@ async def ensure_connected[HandshakeMetadata](
242250 logic that actually establishes the connection.
243251 """
244252
253+ logger .debug ("ensure_connected: %r" , self .is_connected ())
245254 if self .is_connected ():
246255 return
247256
@@ -255,7 +264,9 @@ async def ensure_connected[HandshakeMetadata](
255264 )
256265 )
257266
267+ logger .debug ("BEFORE await _do_ensure_connected" )
258268 await self ._connecting_task
269+ logger .debug ("AFTER await _do_ensure_connected" )
259270
260271 async def _do_ensure_connected [HandshakeMetadata ](
261272 self ,
@@ -271,6 +282,7 @@ async def _do_ensure_connected[HandshakeMetadata](
271282
272283 last_error : Exception | None = None
273284 i = 0
285+ await self ._connection_condition .acquire ()
274286 while rate_limiter .has_budget_or_throw (client_id , ERROR_HANDSHAKE , last_error ):
275287 if i > 0 :
276288 logger .info (f"Retrying build handshake number { i } times" )
@@ -378,6 +390,11 @@ async def websocket_closed_callback() -> None:
378390 last_error = None
379391 rate_limiter .start_restoring_budget (client_id )
380392 self ._state = SessionState .ACTIVE
393+ self ._ws_unwrapped = ws
394+ logger .debug ("Before notify_all: %r %r %r" , self ._state , self ._ws_unwrapped , self ._connection_condition )
395+ self ._connection_condition .notify_all ()
396+ self ._connection_condition .release ()
397+ break
381398 except RiverException as e :
382399 await ws .close ()
383400 raise e
@@ -411,6 +428,7 @@ async def websocket_closed_callback() -> None:
411428 f"Failed to create ws after retrying { max_retry } number of times" ,
412429 ) from last_error
413430
431+ logger .debug ("EXITING _do_ensure_connected" )
414432 return True
415433
416434 def is_closed (self ) -> bool :
@@ -419,7 +437,7 @@ def is_closed(self) -> bool:
419437 Do not send messages, do not expect any more messages to be emitted,
420438 the state is expected to be stale.
421439 """
422- return self ._state not in TerminalStates
440+ return self ._state in TerminalStates
423441
424442 def is_connected (self ) -> bool :
425443 return self ._state == SessionState .ACTIVE
@@ -477,6 +495,7 @@ async def send_message(
477495 serviceName = service_name ,
478496 procedureName = procedure_name ,
479497 )
498+ logger .debug ("SENDING MESSAGE: %r" , msg )
480499
481500 if span :
482501 with use_span (span ):
@@ -516,17 +535,17 @@ async def close(self) -> None:
516535 self ._reset_session_close_countdown ()
517536 await self ._task_manager .cancel_all_tasks ()
518537
519- if self ._ws_unwrapped :
520- # The Session isn't guaranteed to live much longer than this close()
521- # invocation, so let's await this close to avoid dropping the socket.
522- await self ._ws_unwrapped .close ()
523-
524538 # TODO: unexpected_close should close stream differently here to
525539 # throw exception correctly.
526540 for stream in self ._streams .values ():
527541 stream .close ()
528542 self ._streams .clear ()
529543
544+ if self ._ws_unwrapped :
545+ # The Session isn't guaranteed to live much longer than this close()
546+ # invocation, so let's await this close to avoid dropping the socket.
547+ await self ._ws_unwrapped .close ()
548+
530549 self ._state = SessionState .CLOSED
531550
532551 # Clear the session in transports
0 commit comments