Skip to content

Commit 636024e

Browse files
Expose in-flight ws to the session state to permit close()
1 parent 3931e80 commit 636024e

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/replit_river/v2/session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,17 @@ def close_session(reason: Exception | None) -> None:
268268
self.close(reason, current_state=current_state),
269269
)
270270

271-
def transition_connecting() -> None:
271+
def transition_connecting(ws: ClientConnection) -> None:
272272
if self._state in TerminalStates:
273273
return
274274
logger.debug("transition_connecting")
275275
self._state = SessionState.CONNECTING
276276
# "Clear" here means observers should wait until we are connected.
277277
self._wait_for_connected.clear()
278278

279+
# Expose the current ws to be collected by close()
280+
self._ws = ws
281+
279282
def transition_connected(ws: ClientConnection) -> None:
280283
if self._state in TerminalStates:
281284
return
@@ -1108,7 +1111,7 @@ async def _do_ensure_connected[HandshakeMetadata](
11081111
get_next_sent_seq: Callable[[], int],
11091112
get_current_ack: Callable[[], int],
11101113
get_state: Callable[[], SessionState],
1111-
transition_connecting: Callable[[], None],
1114+
transition_connecting: Callable[[ClientConnection], None],
11121115
close_ws_in_background: Callable[[ClientConnection], None],
11131116
transition_connected: Callable[[ClientConnection], None],
11141117
unbind_connecting_task: Callable[[], None],
@@ -1128,12 +1131,12 @@ async def _do_ensure_connected[HandshakeMetadata](
11281131
attempt_count += 1
11291132

11301133
rate_limiter.consume_budget(client_id)
1131-
transition_connecting()
11321134

11331135
ws: ClientConnection | None = None
11341136
try:
11351137
uri_and_metadata = await uri_and_metadata_factory()
11361138
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
1139+
transition_connecting(ws)
11371140

11381141
try:
11391142
handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata](

0 commit comments

Comments
 (0)