Skip to content

Commit 6b17c42

Browse files
Avoid circular awaits
1 parent 3318e98 commit 6b17c42

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/replit_river/v2/session.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ class Session:
121121
ack: int # Most recently acknowledged seq
122122
seq: int # Last sent sequence number
123123

124+
# Terminating
125+
_terminating_task: asyncio.Task[None]
126+
124127
def __init__(
125128
self,
126129
transport_id: str,
@@ -188,13 +191,20 @@ async def ensure_connected[HandshakeMetadata](
188191
if self.is_connected():
189192
return
190193

194+
def do_close() -> None:
195+
# We can't just call self.close() directly because
196+
# we're inside a thread that will eventually be awaited
197+
# during the cleanup procedure.
198+
self._terminating_task = asyncio.create_task(self.close())
199+
191200
if not self._connecting_task:
192201
self._connecting_task = asyncio.create_task(
193202
self._do_ensure_connected(
194203
client_id,
195204
rate_limiter,
196205
uri_and_metadata_factory,
197206
protocol_version,
207+
do_close,
198208
)
199209
)
200210

@@ -208,6 +218,7 @@ async def _do_ensure_connected[HandshakeMetadata](
208218
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
209219
], # noqa: E501
210220
protocol_version: str,
221+
do_close: Callable[[], None],
211222
) -> Literal[True]:
212223
max_retry = self._transport_options.connection_retry_options.max_retry
213224
logger.info("Attempting to establish new ws connection")
@@ -329,7 +340,8 @@ async def websocket_closed_callback() -> None:
329340
handshake_response.status.code
330341
== ERROR_CODE_SESSION_STATE_MISMATCH
331342
):
332-
await self.close()
343+
do_close()
344+
333345
raise RiverException(
334346
ERROR_HANDSHAKE,
335347
f"Handshake failed with code {handshake_response.status.code}: {
@@ -553,14 +565,20 @@ async def block_until_connected() -> None:
553565
)
554566

555567
def _start_close_session_checker(self) -> None:
568+
def do_close() -> None:
569+
# We can't just call self.close() directly because
570+
# we're inside a thread that will eventually be awaited
571+
# during the cleanup procedure.
572+
self._terminating_task = asyncio.create_task(self.close())
573+
556574
self._task_manager.create_task(
557575
_check_to_close_session(
558576
self._transport_id,
559577
self._transport_options.close_session_check_interval_ms,
560578
lambda: self._state,
561579
self._get_current_time,
562580
lambda: self._close_session_after_time_secs,
563-
self.close,
581+
do_close=do_close,
564582
)
565583
)
566584

@@ -986,7 +1004,7 @@ async def _check_to_close_session(
9861004
get_state: Callable[[], SessionState],
9871005
get_current_time: Callable[[], Awaitable[float]],
9881006
get_close_session_after_time_secs: Callable[[], float | None],
989-
do_close: Callable[[], Awaitable[None]],
1007+
do_close: Callable[[], None],
9901008
) -> None:
9911009
our_task = asyncio.current_task()
9921010
while our_task and not our_task.cancelling() and not our_task.cancelled():
@@ -1003,8 +1021,8 @@ async def _check_to_close_session(
10031021
continue
10041022
if current_time > close_session_after_time_secs:
10051023
logger.info("Grace period ended for %s, closing session", transport_id)
1006-
await do_close()
1007-
return
1024+
do_close()
1025+
our_task.cancel()
10081026

10091027

10101028
async def _buffered_message_sender(

0 commit comments

Comments
 (0)