Skip to content

Commit 41272a8

Browse files
Centralizing confusion around blocking vs non-blocking "close()" codepaths
1 parent 313467e commit 41272a8

File tree

1 file changed

+104
-95
lines changed

1 file changed

+104
-95
lines changed

src/replit_river/v2/session.py

Lines changed: 104 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -252,28 +252,6 @@ def get_next_sent_seq() -> int:
252252
return self._send_buffer[0].seq
253253
return self.seq
254254

255-
def close_session(reason: Exception | None) -> None:
256-
# If we're already closing, just let whoever's currently doing it handle it.
257-
if self._state in TerminalStates:
258-
return
259-
260-
# Avoid closing twice
261-
if self._terminating_task is None:
262-
current_state = self._state
263-
self._state = SessionState.CLOSING
264-
265-
# We can't just call self.close() directly because
266-
# we're inside a thread that will eventually be awaited
267-
# during the cleanup procedure.
268-
269-
self._terminating_task = asyncio.create_task(
270-
self.close(
271-
reason,
272-
current_state=current_state,
273-
_wait_for_closed=False,
274-
),
275-
)
276-
277255
def transition_connecting(ws: ClientConnection) -> None:
278256
if self._state in TerminalStates:
279257
return
@@ -328,7 +306,7 @@ def unbind_connecting_task() -> None:
328306
close_ws_in_background=close_ws_in_background,
329307
transition_connected=transition_connected,
330308
unbind_connecting_task=unbind_connecting_task,
331-
close_session=close_session,
309+
close_session=self._close_internal_nowait,
332310
)
333311
)
334312

@@ -413,14 +391,9 @@ async def _enqueue_message(
413391
async def close(
414392
self,
415393
reason: Exception | None = None,
416-
current_state: SessionState | None = None,
417-
_wait_for_closed: bool = True,
418394
) -> None:
419395
"""Close the session and all associated streams."""
420396
if self._closing_waiter:
421-
# Break early for internal callers
422-
if not _wait_for_closed:
423-
return
424397
try:
425398
logger.debug("Session already closing, waiting...")
426399
async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
@@ -431,79 +404,112 @@ async def close(
431404
"seconds to close, leaking",
432405
)
433406
return
434-
logger.info(
435-
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}"
436-
)
437-
self._state = SessionState.CLOSING
438-
self._closing_waiter = asyncio.Event()
407+
await self._close_internal(reason)
439408

440-
# We're closing, so we need to wake up...
441-
# ... tasks waiting for connection to be established
442-
self._wait_for_connected.set()
443-
# ... consumers waiting to enqueue messages
444-
self._space_available.set()
445-
# ... message processor so it can exit cleanly
446-
self._process_messages.set()
409+
def _close_internal_nowait(self, reason: Exception | None = None) -> None:
410+
"""
411+
When calling close() from asyncio Tasks, we must not block.
412+
413+
This function does so, deferring to the underlying infrastructure for
414+
creating self._terminating_task.
415+
"""
416+
self._close_internal(reason)
417+
418+
def _close_internal(self, reason: Exception | None = None) -> asyncio.Task[None]:
419+
"""
420+
Internal close method. Subsequent calls past the first do not block.
421+
422+
This is intended to be the primary driver of a session being torn down
423+
and returned to its initial state.
424+
425+
NB: This function is intended to be the sole lifecycle manager of
426+
self._terminating_task. Waiting on the completion of that task is optional,
427+
but the population of that property is critical.
428+
429+
NB: We must not await the task returned from this function from chained tasks
430+
inside this session, otherwise we will create a thread loop.
431+
"""
432+
433+
async def do_close() -> None:
434+
logger.info(
435+
f"{self.session_id} closing session to {self._server_id}, "
436+
f"ws: {self._ws}"
437+
)
438+
self._state = SessionState.CLOSING
439+
self._closing_waiter = asyncio.Event()
447440

448-
# Wait to permit the waiting tasks to shut down gracefully
449-
await asyncio.sleep(0.25)
441+
# We're closing, so we need to wake up...
442+
# ... tasks waiting for connection to be established
443+
self._wait_for_connected.set()
444+
# ... consumers waiting to enqueue messages
445+
self._space_available.set()
446+
# ... message processor so it can exit cleanly
447+
self._process_messages.set()
448+
449+
# Wait to permit the waiting tasks to shut down gracefully
450+
await asyncio.sleep(0.25)
450451

451-
await self._task_manager.cancel_all_tasks()
452+
await self._task_manager.cancel_all_tasks()
452453

453-
for stream_meta in self._streams.values():
454-
stream_meta["output"].close()
455-
# Wake up backpressured writers
454+
for stream_meta in self._streams.values():
455+
stream_meta["output"].close()
456+
# Wake up backpressured writers
457+
try:
458+
stream_meta["error_channel"].put_nowait(
459+
reason
460+
or SessionClosedRiverServiceException(
461+
"river session is closed",
462+
)
463+
)
464+
except ChannelFull:
465+
logger.exception(
466+
"Unable to tell the caller that the session is going away",
467+
)
468+
stream_meta["release_backpressured_waiter"]()
469+
# Before we GC the streams, let's wait for all tasks to be closed gracefully
456470
try:
457-
stream_meta["error_channel"].put_nowait(
458-
reason
459-
or SessionClosedRiverServiceException(
460-
"river session is closed",
471+
async with asyncio.timeout(
472+
self._transport_options.shutdown_all_streams_timeout_ms
473+
):
474+
# Block for backpressure and emission errors from the ws
475+
await asyncio.gather(
476+
*[
477+
stream_meta["output"].join()
478+
for stream_meta in self._streams.values()
479+
]
461480
)
462-
)
463-
except ChannelFull:
481+
except asyncio.TimeoutError:
482+
spans: list[Span] = [
483+
stream_meta["span"]
484+
for stream_meta in self._streams.values()
485+
if not stream_meta["output"].closed()
486+
]
487+
span_ids = [span.get_span_context().span_id for span in spans]
464488
logger.exception(
465-
"Unable to tell the caller that the session is going away",
489+
"Timeout waiting for output streams to finallize",
490+
extra={"span_ids": span_ids},
466491
)
467-
stream_meta["release_backpressured_waiter"]()
468-
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
469-
try:
470-
async with asyncio.timeout(
471-
self._transport_options.shutdown_all_streams_timeout_ms
472-
):
473-
# Block for backpressure and emission errors from the ws
474-
await asyncio.gather(
475-
*[
476-
stream_meta["output"].join()
477-
for stream_meta in self._streams.values()
478-
]
479-
)
480-
except asyncio.TimeoutError:
481-
spans: list[Span] = [
482-
stream_meta["span"]
483-
for stream_meta in self._streams.values()
484-
if not stream_meta["output"].closed()
485-
]
486-
span_ids = [span.get_span_context().span_id for span in spans]
487-
logger.exception(
488-
"Timeout waiting for output streams to finallize",
489-
extra={"span_ids": span_ids},
490-
)
491-
self._streams.clear()
492+
self._streams.clear()
492493

493-
if self._ws:
494-
# The Session isn't guaranteed to live much longer than this close()
495-
# invocation, so let's await this close to avoid dropping the socket.
496-
await self._ws.close()
494+
if self._ws:
495+
# The Session isn't guaranteed to live much longer than this close()
496+
# invocation, so let's await this close to avoid dropping the socket.
497+
await self._ws.close()
497498

498-
self._state = SessionState.CLOSED
499+
self._state = SessionState.CLOSED
499500

500-
# Clear the session in transports
501-
# This will get us GC'd, so this should be the last thing.
502-
self._close_session_callback(self)
501+
# Clear the session in transports
502+
# This will get us GC'd, so this should be the last thing.
503+
self._close_session_callback(self)
503504

504-
# Release waiters, then release the event
505-
self._closing_waiter.set()
506-
self._closing_waiter = None
505+
# Release waiters, then release the event
506+
self._closing_waiter.set()
507+
self._closing_waiter = None
508+
509+
if self._terminating_task:
510+
return self._terminating_task
511+
512+
return asyncio.create_task(do_close())
507513

508514
def _start_buffered_message_sender(
509515
self,
@@ -646,7 +652,7 @@ async def block_until_connected() -> None:
646652
get_state=lambda: self._state,
647653
get_ws=lambda: self._ws,
648654
transition_no_connection=transition_no_connection,
649-
close_session=lambda err: self.close(err, _wait_for_closed=False),
655+
close_session=self._close_internal_nowait,
650656
assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping,
651657
get_stream=lambda stream_id: self._streams.get(stream_id),
652658
enqueue_message=self._enqueue_message,
@@ -1137,8 +1143,11 @@ async def websocket_closed_callback() -> None:
11371143

11381144
try:
11391145
data = await ws.recv(decode=False)
1140-
except ConnectionClosedOK as e:
1141-
close_session(e)
1146+
except ConnectionClosedOK:
1147+
# In the case of a normal connection closure, we defer to
1148+
# the outer loop to determine next steps.
1149+
# A call to close(...) should set the SessionState to a terminal one,
1150+
# otherwise we should try again.
11421151
continue
11431152
except ConnectionClosed as e:
11441153
logger.debug(
@@ -1226,7 +1235,7 @@ async def _recv_from_ws(
12261235
get_state: Callable[[], SessionState],
12271236
get_ws: Callable[[], ClientConnection | None],
12281237
transition_no_connection: Callable[[], Awaitable[None]],
1229-
close_session: Callable[[Exception | None], Awaitable[None]],
1238+
close_session: Callable[[Exception | None], None],
12301239
assert_incoming_seq_bookkeeping: Callable[
12311240
[str, int, int], Literal[True] | _IgnoreMessage
12321241
],
@@ -1361,7 +1370,7 @@ async def _recv_from_ws(
13611370
stream_meta["output"].close()
13621371
except OutOfOrderMessageException:
13631372
logger.exception("Out of order message, closing connection")
1364-
await close_session(
1373+
close_session(
13651374
SessionClosedRiverServiceException(
13661375
"Out of order message, closing connection"
13671376
)
@@ -1371,7 +1380,7 @@ async def _recv_from_ws(
13711380
logger.exception(
13721381
"Got invalid transport message, closing session",
13731382
)
1374-
await close_session(
1383+
close_session(
13751384
SessionClosedRiverServiceException(
13761385
"Out of order message, closing connection"
13771386
)

0 commit comments

Comments
 (0)