diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py
index 2bfc60ee..810a4971 100644
--- a/src/replit_river/v2/client_transport.py
+++ b/src/replit_river/v2/client_transport.py
@@ -1,10 +1,14 @@
+import asyncio
 import logging
+from asyncio import Event, shield
 from collections.abc import Awaitable, Callable
 from typing import Generic
 
 import nanoid
+from websockets.asyncio.client import ClientConnection
 
 from replit_river.rate_limiter import LeakyBucketRateLimit
+from replit_river.task_manager import BackgroundTaskManager
 from replit_river.transport_options import (
     HandshakeMetadataType,
     TransportOptions,
@@ -17,6 +21,7 @@
 
 class ClientTransport(Generic[HandshakeMetadataType]):
     _session: Session | None
+    _closing_event: tuple[Session, Event, Awaitable[None]] | None
 
     def __init__(
         self,
@@ -28,6 +33,7 @@ def __init__(
         self._session = None
         self._transport_id = nanoid.generate()
         self._transport_options = transport_options
+        self._closing_event = None
         self._uri_and_metadata_factory = uri_and_metadata_factory
         self._client_id = client_id
 
@@ -37,16 +43,57 @@ def __init__(
         )
 
     async def close(self) -> None:
+        """
+        Close the rate limiter, then wait for the current session's close
+        event and for any close that is already in flight.
+        The teardown sequence itself is driven by _trigger_close.
+        """
         self._rate_limiter.close()
         if self._session:
-            await self._session.close()
-        logger.info(
-            "Transport closed",
-            extra={
-                "client_id": self._client_id,
-                "transport_id": self._transport_id,
-            },
-        )
+            await self._session.close().wait()
+
+        if self._closing_event:
+            await self._closing_event[1].wait()
+
+    def _trigger_close(
+        self,
+        signal_closing: Callable[[], None],
+        task_manager: BackgroundTaskManager,  # .cancel_all_tasks()
+        terminate_remaining_output_streams: Callable[[], None],
+        join_output_streams_with_timeout: Callable[[], Awaitable[None]],
+        ws: ClientConnection | None,
+        become_closed: Callable[[], None],
+    ) -> Event:
+        if self._closing_event and not self._closing_event[1].is_set():
+            return self._closing_event[1]
+        if self._session is None:
+            noop = asyncio.Event()
+            noop.set()
+            return noop
+
+        closing_event = Event()
+
+        async def _do_close() -> None:
+            session = self._session
+            signal_closing()
+            await task_manager.cancel_all_tasks()
+            terminate_remaining_output_streams()
+            await join_output_streams_with_timeout()
+            if ws:
+                await ws.close()
+            become_closed()
+            # Ensure that we've not established a new session in the
+            # meantime somehow.
+            if self._session is session:
+                self._session = None
+            closing_event.set()
+
+        self._closing_event = (
+            self._session,
+            closing_event,
+            shield(asyncio.create_task(_do_close(), name="do_close")),
+        )
+        return self._closing_event[1]
 
     async def get_or_create_session(self) -> Session:
         """
@@ -57,16 +104,23 @@ async def get_or_create_session(self) -> Session:
         if not existing_session or existing_session.is_terminal():
             logger.info("Creating new session")
             if existing_session:
-                await existing_session.close()
+                await existing_session.close().wait()
+                if self._closing_event and self._closing_event[0] is existing_session:
+                    await self._closing_event[2]
+                else:
+                    logger.error(
+                        "This should not be possible, "
+                        "self._closing_event should always refer to existing_session",
+                    )
             new_session = Session(
                 client_id=self._client_id,
                 server_id=self._server_id,
                 session_id=nanoid.generate(),
                 transport_options=self._transport_options,
-                close_session_callback=self._delete_session,
                 retry_connection_callback=self._retry_connection,
                 uri_and_metadata_factory=self._uri_and_metadata_factory,
                 rate_limiter=self._rate_limiter,
+                trigger_close=self._trigger_close,
             )
 
             self._session = new_session
@@ -78,21 +132,6 @@ async def get_or_create_session(self) -> Session:
     async def _retry_connection(self) -> Session:
         if self._session and not self._transport_options.transparent_reconnect:
             logger.info("transparent_reconnect not set, closing {self._transport_id}")
-            await self._session.close()
+            await self._session.close().wait()
         logger.debug("Triggering get_or_create_session")
         return await self.get_or_create_session()
-
-    def _delete_session(self, session: Session) -> None:
-        if self._session is session:
-            self._session = None
-        else:
-            logger.warning(
-                "Session attempted to close itself but it was not the "
-                "active session, doing nothing",
-                extra={
-                    "client_id": self._client_id,
-                    "transport_id": self._transport_id,
-                    "active_session_id": self._session and self._session.session_id,
-                    "orphan_session_id": session.session_id,
-                },
-            )
diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py
index 9d8ce129..7c1ffd4b 100644
--- a/src/replit_river/v2/session.py
+++ b/src/replit_river/v2/session.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import time
 from collections import deque
 from collections.abc import AsyncIterable
 from contextlib import asynccontextmanager
@@ -14,6 +15,7 @@
     Coroutine,
     Literal,
     NotRequired,
+    Protocol,
     TypeAlias,
     TypedDict,
     assert_never,
@@ -112,7 +114,6 @@ class ResultError(TypedDict):
 trace_propagator = TraceContextTextMapPropagator()
 trace_setter = TransportMessageTracingSetter()
 
-CloseSessionCallback: TypeAlias = Callable[["Session"], None]
 RetryConnectionCallback: TypeAlias = Callable[
     [],
     Coroutine[Any, Any, Any],
@@ -131,6 +132,18 @@ class StreamMeta(TypedDict):
     output: Channel[ResultType]
 
 
+class TriggerCloseCall(Protocol):
+    def __call__(
+        self,
+        signal_closing: Callable[[], None],
+        task_manager: BackgroundTaskManager,  # .cancel_all_tasks()
+        terminate_remaining_output_streams: Callable[[], None],
+        join_output_streams_with_timeout: Callable[[], Awaitable[None]],
+        ws: ClientConnection | None,
+        become_closed: Callable[[], None],
+    ) -> asyncio.Event: ...
+
+
 class Session[HandshakeMetadata]:
     _server_id: str
     session_id: str
@@ -138,7 +151,6 @@ class Session[HandshakeMetadata]:
 
     # session state, only modified during closing
     _state: SessionState
-    _close_session_callback: CloseSessionCallback
     _close_session_after_time_secs: float | None
     _connecting_task: asyncio.Task[None] | None
     _wait_for_connected: asyncio.Event
@@ -168,19 +180,19 @@ class Session[HandshakeMetadata]:
     seq: int  # Last sent sequence number
 
     # Terminating
-    _terminating_task: asyncio.Task[None] | None
+    _trigger_close: TriggerCloseCall
 
     def __init__(
         self,
         server_id: str,
         session_id: str,
         transport_options: TransportOptions,
-        close_session_callback: CloseSessionCallback,
         client_id: str,
         rate_limiter: RateLimiter,
         uri_and_metadata_factory: Callable[
             [], Awaitable[UriAndMetadata[HandshakeMetadata]]
         ],
+        trigger_close: TriggerCloseCall,
         retry_connection_callback: RetryConnectionCallback | None = None,
     ) -> None:
         self._server_id = server_id
@@ -189,7 +201,6 @@ def __init__(
 
         # session state
         self._state = SessionState.NO_CONNECTION
-        self._close_session_callback = close_session_callback
         self._close_session_after_time_secs: float | None = None
         self._connecting_task = None
         self._wait_for_connected = asyncio.Event()
@@ -227,7 +238,7 @@ def __init__(
         self.seq = 0
 
         # Terminating
-        self._terminating_task = None
+        self._trigger_close = trigger_close
 
         self._start_recv_from_ws()
         self._start_buffered_message_sender()
@@ -298,13 +309,13 @@ def unbind_connecting_task() -> None:
                     uri_and_metadata_factory=self._uri_and_metadata_factory,
                     get_next_sent_seq=get_next_sent_seq,
                     get_current_ack=lambda: self.ack,
-                    get_current_time=self._get_current_time,
+                    get_current_time=time.monotonic,
                     get_state=lambda: self._state,
                     transition_connecting=transition_connecting,
                     close_ws_in_background=close_ws_in_background,
                     transition_connected=transition_connected,
                     unbind_connecting_task=unbind_connecting_task,
-                    close_session=self._close_internal_nowait,
+                    close_session=self.close,
                 )
             )
 
@@ -313,12 +324,6 @@ def unbind_connecting_task() -> None:
             except asyncio.CancelledError:
                 pass
 
-        if self._terminating_task:
-            try:
-                await self._terminating_task
-            except asyncio.CancelledError:
-                pass
-
     def is_terminal(self) -> bool:
         """
         If the session is in a terminal state.
@@ -330,9 +335,6 @@ def is_terminal(self) -> bool:
     def is_connected(self) -> bool:
         return self._state in ActiveStates
 
-    async def _get_current_time(self) -> float:
-        return asyncio.get_event_loop().time()
-
     async def _enqueue_message(
         self,
         stream_id: str,
@@ -394,52 +396,16 @@ async def _enqueue_message(
         # Wake up buffered_message_sender
         self._process_messages.set()
 
-    async def close(
+    def close(
         self,
         reason: Exception | None = None,
-    ) -> None:
+    ) -> asyncio.Event:
         """Close the session and all associated streams."""
-        if self._terminating_task:
-            try:
-                logger.debug("Session already closing, waiting...")
-                async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
-                    await self._terminating_task
-            except asyncio.TimeoutError:
-                logger.warning(
-                    f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
-                    "seconds to close, leaking",
-                )
-            return
-        try:
-            await self._close_internal(reason)
-        except asyncio.CancelledError:
-            pass
-
-    def _close_internal_nowait(self, reason: Exception | None = None) -> None:
-        """
-        When calling close() from asyncio Tasks, we must not block.
-
-        This function does so, deferring to the underlying infrastructure for
-        creating self._terminating_task.
-        """
-        self._close_internal(reason)
-
-    def _close_internal(self, reason: Exception | None = None) -> asyncio.Task[None]:
-        """
-        Internal close method. Subsequent calls past the first do not block.
-
-        This is intended to be the primary driver of a session being torn down
-        and returned to its initial state.
-
-        NB: This function is intended to be the sole lifecycle manager of
-        self._terminating_task. Waiting on the completion of that task is optional,
-        but the population of that property is critical.
-
-        NB: We must not await the task returned from this function from chained tasks
-        inside this session, otherwise we will create a thread loop.
-        """
 
-        async def do_close() -> None:
+        def signal_closing() -> None:
+            """
+            Roughly "kill 15"
+            """
             logger.info(
                 f"{self.session_id} closing session to {self._server_id}, "
                 f"ws: {self._ws}"
@@ -454,11 +420,10 @@ async def do_close() -> None:
             # ... message processor so it can exit cleanly
             self._process_messages.set()
 
-            # Wait to permit the waiting tasks to shut down gracefully
-            await asyncio.sleep(0.25)
-
-            await self._task_manager.cancel_all_tasks()
-
+        def terminate_remaining_output_streams() -> None:
+            """
+            Roughly "kill 9"
+            """
             for stream_id, stream_meta in self._streams.items():
                 stream_meta["output"].close()
                 # Wake up backpressured writers
@@ -475,6 +440,11 @@ async def do_close() -> None:
                         "Unable to tell the caller that the session is going away",
                     )
                 stream_meta["release_backpressured_waiter"]()
+
+        async def join_output_streams_with_timeout() -> None:
+            """
+            Roughly "wait"
+            """
             # Before we GC the streams, let's wait for all tasks to be closed gracefully
             try:
                 async with asyncio.timeout(
@@ -500,21 +470,21 @@ async def do_close() -> None:
                     )
             self._streams.clear()
 
-            if self._ws:
-                # The Session isn't guaranteed to live much longer than this close()
-                # invocation, so let's await this close to avoid dropping the socket.
-                await self._ws.close()
-
+        def become_closed() -> None:
+            """Mark this session as CLOSED once teardown has finished."""
             self._state = SessionState.CLOSED
 
             # Clear the session in transports
             # This will get us GC'd, so this should be the last thing.
-            self._close_session_callback(self)
 
-        if not self._terminating_task:
-            self._terminating_task = asyncio.create_task(do_close())
-
-        return self._terminating_task
+        return self._trigger_close(
+            signal_closing,
+            self._task_manager,  # .cancel_all_tasks()
+            terminate_remaining_output_streams,
+            join_output_streams_with_timeout,
+            self._ws,
+            become_closed,
+        )
 
     def _start_buffered_message_sender(
         self,
@@ -657,7 +627,7 @@ async def block_until_connected() -> None:
                 get_state=lambda: self._state,
                 get_ws=lambda: self._ws,
                 transition_no_connection=transition_no_connection,
-                close_session=self._close_internal_nowait,
+                close_session=self.close,
                 assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping,
                 get_stream=lambda stream_id: self._streams.get(stream_id),
                 enqueue_message=self._enqueue_message,
@@ -860,6 +830,10 @@ async def send_upload[I, R, A](
                     span=span,
                 )
                 raise
+            except SessionClosedRiverServiceException as e:
+                raise RiverServiceException(
+                    ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name
+                ) from e
             except Exception as e:
                 # If we get any exception other than WebsocketClosedException,
                 # cancel the stream.
@@ -1097,7 +1071,7 @@ async def _do_ensure_connected[HandshakeMetadata](
     uri_and_metadata_factory: Callable[
         [], Awaitable[UriAndMetadata[HandshakeMetadata]]
     ],
-    get_current_time: Callable[[], Awaitable[float]],
+    get_current_time: Callable[[], float],
     get_next_sent_seq: Callable[[], int],
     get_current_ack: Callable[[], int],
     get_state: Callable[[], SessionState],
@@ -1105,17 +1079,27 @@ async def _do_ensure_connected[HandshakeMetadata](
     close_ws_in_background: Callable[[ClientConnection], None],
     transition_connected: Callable[[ClientConnection], None],
     unbind_connecting_task: Callable[[], None],
-    close_session: Callable[[Exception | None], None],
+    close_session: Callable[[Exception | None], asyncio.Event],
 ) -> None:
     logger.info("Attempting to establish new ws connection")
 
     last_error: Exception | None = None
     attempt_count = 0
-    while rate_limiter.has_budget(client_id):
+
+    handshake_deadline = (
+        get_current_time() + transport_options.handshake_timeout_ms / 1000
+    )
+
+    while (
+        rate_limiter.has_budget(client_id) and get_current_time() < handshake_deadline
+    ):
         if (state := get_state()) in TerminalStates or state in ActiveStates:
             logger.info(f"_do_ensure_connected stopping due to state={state}")
             break
 
+        if (task := asyncio.current_task()) and task.cancelling():
+            break
+
         if attempt_count > 0:
             logger.info(f"Retrying build handshake number {attempt_count} times")
         attempt_count += 1
@@ -1169,18 +1153,16 @@ async def websocket_closed_callback() -> None:
                     "Handshake failed, conn closed while sending response",
                 ) from e
 
-            handshake_deadline_ms = (
-                await get_current_time() + transport_options.handshake_timeout_ms
-            )
-
-            if await get_current_time() >= handshake_deadline_ms:
+            if get_current_time() >= handshake_deadline:
                 raise RiverException(
                     ERROR_HANDSHAKE,
                     "Handshake response timeout, closing connection",
                 )
 
             try:
-                data = await ws.recv(decode=False)
+                timeout = handshake_deadline - get_current_time()
+                async with asyncio.timeout(timeout):
+                    data = await ws.recv(decode=False)
             except ConnectionClosedOK:
                 # In the case of a normal connection closure, we defer to
                 # the outer loop to determine next steps.
@@ -1197,6 +1179,16 @@ async def websocket_closed_callback() -> None:
                     ERROR_HANDSHAKE,
                     "Handshake failed, conn closed while waiting for response",
                 ) from e
+            except asyncio.TimeoutError as e:
+                logger.debug(
+                    "_do_ensure_connected: Response timeout while waiting "
+                    "for handshake response",
+                    exc_info=True,
+                )
+                raise RiverException(
+                    ERROR_HANDSHAKE,
+                    "Handshake failed, timeout while waiting for response",
+                ) from e
 
             try:
                 response_msg = parse_transport_msg(data)
@@ -1244,15 +1236,15 @@ async def websocket_closed_callback() -> None:
             transition_connected(ws)
             break
         except Exception as e:
-            backoff_time = rate_limiter.get_backoff_ms(client_id)
+            backoff_time_ms = rate_limiter.get_backoff_ms(client_id)
             logger.exception(
-                f"Error connecting, retrying with {backoff_time}ms backoff"
+                f"Error connecting, retrying with {backoff_time_ms}ms backoff"
             )
             if ws:
                 close_ws_in_background(ws)
                 ws = None
             last_error = e
-            await asyncio.sleep(backoff_time / 1000)
+            await asyncio.sleep(backoff_time_ms / 1000)
             logger.debug("Here, about to retry")
 
     unbind_connecting_task()
@@ -1273,7 +1265,7 @@ async def _recv_from_ws(
     get_state: Callable[[], SessionState],
     get_ws: Callable[[], ClientConnection | None],
     transition_no_connection: Callable[[], Awaitable[None]],
-    close_session: Callable[[Exception | None], None],
+    close_session: Callable[[Exception | None], asyncio.Event],
     assert_incoming_seq_bookkeeping: Callable[
         [str, int, int], Literal[True] | _IgnoreMessage
     ],
diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py
index b259d27e..13e8adef 100644
--- a/tests/v2/test_v2_cancellation.py
+++ b/tests/v2/test_v2_cancellation.py
@@ -10,6 +10,7 @@
 import msgpack
 import nanoid
 import pytest
+from pydantic import TypeAdapter
 
 from replit_river.messages import parse_transport_msg
 from replit_river.rpc import (
@@ -24,6 +25,8 @@
 from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT
 from tests.v2.fixtures.raw_ws_server import OuterPayload, WsServerFixture
 
+logger = logging.getLogger(__name__)
+
 
 async def test_rpc_cancel(ws_server: WsServerFixture) -> None:
     (urimeta, recv, conn) = ws_server
@@ -226,7 +229,7 @@ async def _upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
             lambda x: x,
             lambda x: x,
         ):
-            print(repr(chunk))
+            logger.debug(repr(chunk))
 
     receive_task = asyncio.create_task(receive_chunks())
     request_msg = parse_transport_msg(await recv.get())
@@ -277,9 +280,9 @@ async def test_subscription_cancel(ws_server: WsServerFixture) -> None:
     assert not isinstance(request_msg, str)
     assert (serverconn := conn())
 
-    handshake_request: ControlMessageHandshakeRequest[None] = (
-        ControlMessageHandshakeRequest(**request_msg.payload)
-    )
+    handshake_request = TypeAdapter(
+        ControlMessageHandshakeRequest[None]
+    ).validate_python(request_msg.payload)
 
     handshake_resp = ControlMessageHandshakeResponse(
         status=HandShakeStatus(
@@ -358,7 +361,7 @@ async def receive_chunks() -> None:
             lambda x: x,
             lambda x: x,
         ):
-            print(repr(chunk))
+            logger.debug(repr(chunk))
 
     receive_task = asyncio.create_task(receive_chunks())
 
@@ -374,6 +377,9 @@ async def receive_chunks() -> None:
     await client.close()
     await connecting
 
+    # Wait until the close signal makes it back to the server
+    await asyncio.sleep(0.1)
+
     # Ensure we're listening to close messages as well
     server_handler.cancel()
     await server_handler
diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py
index 736a35f8..c13066e0 100644
--- a/tests/v2/test_v2_session_lifecycle.py
+++ b/tests/v2/test_v2_session_lifecycle.py
@@ -1,8 +1,10 @@
 import asyncio
 import logging
+from typing import Awaitable, Callable
 
 import msgpack
 import nanoid
+from websockets.asyncio.client import ClientConnection
 
 from replit_river.common_session import SessionState
 from replit_river.messages import parse_transport_msg
@@ -13,11 +15,14 @@
     HandShakeStatus,
     TransportMessage,
 )
+from replit_river.task_manager import BackgroundTaskManager
 from replit_river.transport_options import TransportOptions
 from replit_river.v2.client import Client
 from replit_river.v2.session import STREAM_CLOSED_BIT, Session
 from tests.v2.fixtures.raw_ws_server import WsServerFixture
 
+logger = logging.getLogger(__name__)
+
 
 class _PermissiveRateLimiter(RateLimiter):
     def start_restoring_budget(self, user: str) -> None:
@@ -36,21 +41,51 @@ def consume_budget(self, user: str) -> None:
 async def test_connect(ws_server: WsServerFixture) -> None:
     (urimeta, recv, conn) = ws_server
 
+    ws_close: asyncio.Task | None = None
+
+    def trigger_close(
+        signal_closing: Callable[[], None],
+        task_manager: BackgroundTaskManager,  # .cancel_all_tasks()
+        terminate_remaining_output_streams: Callable[[], None],
+        join_output_streams_with_timeout: Callable[[], Awaitable[None]],
+        ws: ClientConnection | None,
+        become_closed: Callable[[], None],
+    ) -> asyncio.Event:
+        nonlocal ws_close
+
+        closing_event = asyncio.Event()
+
+        async def _do_close() -> None:
+            signal_closing()
+            await task_manager.cancel_all_tasks()
+            terminate_remaining_output_streams()
+            await join_output_streams_with_timeout()
+            if ws:
+                await ws.close()
+            become_closed()
+            closing_event.set()
+
+        ws_close = asyncio.create_task(_do_close())
+
+        return closing_event
+
     session = Session(
         server_id="SERVER",
         session_id="SESSION1",
         transport_options=TransportOptions(),
-        close_session_callback=lambda _: None,
         client_id="CLIENT1",
         rate_limiter=_PermissiveRateLimiter(),
         uri_and_metadata_factory=urimeta,
+        trigger_close=trigger_close,
     )
 
     connecting = asyncio.create_task(session.ensure_connected())
     msg = parse_transport_msg(await recv.get())
     assert isinstance(msg, TransportMessage)
     assert msg.payload["type"] == "HANDSHAKE_REQ"
-    await session.close()
+    await session.close().wait()
+    assert ws_close is not None
+    await ws_close
     await connecting
 
 
@@ -59,28 +94,57 @@ async def test_close_race(ws_server: WsServerFixture) -> None:
 
     callcount = 0
 
-    def close_session_callback(_session: Session) -> None:
-        nonlocal callcount
-        callcount += 1
+    event: asyncio.Event | None = None
+    ws_close: asyncio.Task | None = None
+
+    def trigger_close(
+        signal_closing: Callable[[], None],
+        task_manager: BackgroundTaskManager,  # .cancel_all_tasks()
+        terminate_remaining_output_streams: Callable[[], None],
+        join_output_streams_with_timeout: Callable[[], Awaitable[None]],
+        ws: ClientConnection | None,
+        become_closed: Callable[[], None],
+    ) -> asyncio.Event:
+        nonlocal event
+        nonlocal ws_close
+
+        if event is None:
+            event = asyncio.Event()
+            event.set()
+            nonlocal callcount
+            callcount += 1
+
+            async def _do_close() -> None:
+                signal_closing()
+                await task_manager.cancel_all_tasks()
+                terminate_remaining_output_streams()
+                await join_output_streams_with_timeout()
+                if ws:
+                    await ws.close()
+                become_closed()
+
+            ws_close = asyncio.create_task(_do_close())
+
+        return event
 
     session = Session(
         server_id="SERVER",
         session_id="SESSION1",
         transport_options=TransportOptions(),
-        close_session_callback=close_session_callback,
         client_id="CLIENT1",
         rate_limiter=_PermissiveRateLimiter(),
         uri_and_metadata_factory=urimeta,
+        trigger_close=trigger_close,
     )
 
     connecting = asyncio.create_task(session.ensure_connected())
     msg = parse_transport_msg(await recv.get())
     assert isinstance(msg, TransportMessage)
     assert msg.payload["type"] == "HANDSHAKE_REQ"
-    await session.close()
-    await session.close()
-    await session.close()
-    await session.close()
+    await session.close().wait()
+    await session.close().wait()
+    await session.close().wait()
+    await session.close().wait()
     await connecting
     assert session._state == SessionState.CLOSED
     assert callcount == 1
@@ -160,9 +224,9 @@ async def handle_server_messages() -> None:
         async for datagram in client.send_subscription(
             "test", "bigstream", {}, lambda x: x, lambda x: x, lambda x: x
         ):
-            print(datagram)
+            logger.debug(datagram)
     except Exception:
-        logging.exception("Interrupted")
+        logger.exception("Interrupted")
 
     await client.close()
     await connecting