@@ -268,14 +268,17 @@ def close_session(reason: Exception | None) -> None:
268268 self .close (reason , current_state = current_state ),
269269 )
270270
271- def transition_connecting () -> None :
271+ def transition_connecting (ws : ClientConnection ) -> None :
272272 if self ._state in TerminalStates :
273273 return
274274 logger .debug ("transition_connecting" )
275275 self ._state = SessionState .CONNECTING
276276 # "Clear" here means observers should wait until we are connected.
277277 self ._wait_for_connected .clear ()
278278
279+ # Expose the current ws to be collected by close()
280+ self ._ws = ws
281+
279282 def transition_connected (ws : ClientConnection ) -> None :
280283 if self ._state in TerminalStates :
281284 return
@@ -1043,7 +1046,7 @@ async def _do_ensure_connected[HandshakeMetadata](
10431046 get_next_sent_seq : Callable [[], int ],
10441047 get_current_ack : Callable [[], int ],
10451048 get_state : Callable [[], SessionState ],
1046- transition_connecting : Callable [[], None ],
1049+ transition_connecting : Callable [[ClientConnection ], None ],
10471050 close_ws_in_background : Callable [[ClientConnection ], None ],
10481051 transition_connected : Callable [[ClientConnection ], None ],
10491052 unbind_connecting_task : Callable [[], None ],
@@ -1063,12 +1066,12 @@ async def _do_ensure_connected[HandshakeMetadata](
10631066 attempt_count += 1
10641067
10651068 rate_limiter .consume_budget (client_id )
1066- transition_connecting ()
10671069
10681070 ws : ClientConnection | None = None
10691071 try :
10701072 uri_and_metadata = await uri_and_metadata_factory ()
10711073 ws = await websockets .asyncio .client .connect (uri_and_metadata ["uri" ])
1074+ transition_connecting (ws )
10721075
10731076 try :
10741077 handshake_request = ControlMessageHandshakeRequest [HandshakeMetadata ](
0 commit comments