@@ -268,14 +268,17 @@ def close_session(reason: Exception | None) -> None:
268268 self .close (reason , current_state = current_state ),
269269 )
270270
271- def transition_connecting () -> None :
271+ def transition_connecting (ws : ClientConnection ) -> None :
272272 if self ._state in TerminalStates :
273273 return
274274 logger .debug ("transition_connecting" )
275275 self ._state = SessionState .CONNECTING
276276 # "Clear" here means observers should wait until we are connected.
277277 self ._wait_for_connected .clear ()
278278
279+ # Expose the current ws to be collected by close()
280+ self ._ws = ws
281+
279282 def transition_connected (ws : ClientConnection ) -> None :
280283 if self ._state in TerminalStates :
281284 return
@@ -1108,7 +1111,7 @@ async def _do_ensure_connected[HandshakeMetadata](
11081111 get_next_sent_seq : Callable [[], int ],
11091112 get_current_ack : Callable [[], int ],
11101113 get_state : Callable [[], SessionState ],
1111- transition_connecting : Callable [[], None ],
1114+ transition_connecting : Callable [[ClientConnection ], None ],
11121115 close_ws_in_background : Callable [[ClientConnection ], None ],
11131116 transition_connected : Callable [[ClientConnection ], None ],
11141117 unbind_connecting_task : Callable [[], None ],
@@ -1128,12 +1131,12 @@ async def _do_ensure_connected[HandshakeMetadata](
11281131 attempt_count += 1
11291132
11301133 rate_limiter .consume_budget (client_id )
1131- transition_connecting ()
11321134
11331135 ws : ClientConnection | None = None
11341136 try :
11351137 uri_and_metadata = await uri_and_metadata_factory ()
11361138 ws = await websockets .asyncio .client .connect (uri_and_metadata ["uri" ])
1139+ transition_connecting (ws )
11371140
11381141 try :
11391142 handshake_request = ControlMessageHandshakeRequest [HandshakeMetadata ](
0 commit comments