Skip to content

Commit 9c01f0e

Browse files
Better background task management
1 parent 0c07f6f commit 9c01f0e

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/replit_river/v2/session.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ def transition_connected(ws: ClientConnection) -> None:
226226
# We're connected, wake everybody up using set()
227227
self._wait_for_connected.set()
228228

229+
def close_ws_in_background(ws: ClientConnection) -> None:
230+
self._task_manager.create_task(ws.close())
231+
229232
def finalize_attempt() -> None:
230233
# We are in a state where we may throw an exception.
231234
#
@@ -239,7 +242,7 @@ def finalize_attempt() -> None:
239242
if (
240243
self._connecting_task
241244
and current_task
242-
and self._connecting_task.get_name() == current_task.get_name()
245+
and self._connecting_task is current_task
243246
):
244247
self._connecting_task = None
245248

@@ -257,6 +260,7 @@ def finalize_attempt() -> None:
257260
get_current_ack=lambda: self.ack,
258261
get_current_time=self._get_current_time,
259262
transition_connecting=transition_connecting,
263+
close_ws_in_background=close_ws_in_background,
260264
transition_connected=transition_connected,
261265
finalize_attempt=finalize_attempt,
262266
do_close=do_close,
@@ -973,6 +977,7 @@ async def _do_ensure_connected[HandshakeMetadata](
973977
get_next_sent_seq: Callable[[], int],
974978
get_current_ack: Callable[[], int],
975979
transition_connecting: Callable[[], None],
980+
close_ws_in_background: Callable[[ClientConnection], None],
976981
transition_connected: Callable[[ClientConnection], None],
977982
finalize_attempt: Callable[[], None],
978983
do_close: Callable[[], None],
@@ -989,7 +994,7 @@ async def _do_ensure_connected[HandshakeMetadata](
989994
rate_limiter.consume_budget(client_id)
990995
transition_connecting()
991996

992-
ws = None
997+
ws: ClientConnection | None = None
993998
try:
994999
uri_and_metadata = await uri_and_metadata_factory()
9951000
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
@@ -1085,13 +1090,15 @@ async def websocket_closed_callback() -> None:
10851090
}",
10861091
)
10871092

1093+
# We did it! We're connected!
10881094
last_error = None
10891095
rate_limiter.start_restoring_budget(client_id)
10901096
transition_connected(ws)
10911097
break
10921098
except Exception as e:
10931099
if ws:
1094-
await ws.close()
1100+
close_ws_in_background(ws)
1101+
ws = None
10951102
last_error = e
10961103
backoff_time = rate_limiter.get_backoff_ms(client_id)
10971104
logger.exception(

0 commit comments

Comments
 (0)