Skip to content

Commit 39e6f2f

Browse files
Expose in-flight ws to the session state to permit close()
1 parent ea5f5cd commit 39e6f2f

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/replit_river/v2/session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,17 @@ def close_session(reason: Exception | None) -> None:
268268
self.close(reason, current_state=current_state),
269269
)
270270

271-
def transition_connecting() -> None:
271+
def transition_connecting(ws: ClientConnection) -> None:
272272
if self._state in TerminalStates:
273273
return
274274
logger.debug("transition_connecting")
275275
self._state = SessionState.CONNECTING
276276
# "Clear" here means observers should wait until we are connected.
277277
self._wait_for_connected.clear()
278278

279+
# Expose the current ws to be collected by close()
280+
self._ws = ws
281+
279282
def transition_connected(ws: ClientConnection) -> None:
280283
if self._state in TerminalStates:
281284
return
@@ -1043,7 +1046,7 @@ async def _do_ensure_connected[HandshakeMetadata](
10431046
get_next_sent_seq: Callable[[], int],
10441047
get_current_ack: Callable[[], int],
10451048
get_state: Callable[[], SessionState],
1046-
transition_connecting: Callable[[], None],
1049+
transition_connecting: Callable[[ClientConnection], None],
10471050
close_ws_in_background: Callable[[ClientConnection], None],
10481051
transition_connected: Callable[[ClientConnection], None],
10491052
unbind_connecting_task: Callable[[], None],
@@ -1063,12 +1066,12 @@ async def _do_ensure_connected[HandshakeMetadata](
10631066
attempt_count += 1
10641067

10651068
rate_limiter.consume_budget(client_id)
1066-
transition_connecting()
10671069

10681070
ws: ClientConnection | None = None
10691071
try:
10701072
uri_and_metadata = await uri_and_metadata_factory()
10711073
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
1074+
transition_connecting(ws)
10721075

10731076
try:
10741077
handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata](

0 commit comments

Comments
 (0)