diff --git a/scripts/lint/src/lint/__init__.py b/scripts/lint/src/lint/__init__.py index f5b5c3d3..9d62a193 100644 --- a/scripts/lint/src/lint/__init__.py +++ b/scripts/lint/src/lint/__init__.py @@ -11,7 +11,8 @@ def raise_err(code: int) -> None: def main() -> None: fix = ["--fix"] if "--fix" in sys.argv else [] + watch = ["--watch"] if "--watch" in sys.argv else [] raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix))) raise_err(os.system("ruff format src scripts tests")) - raise_err(os.system("mypy src")) - raise_err(os.system("pyright src")) + raise_err(os.system("mypy src tests")) + raise_err(os.system(" ".join(["pyright"] + watch + ["src", "tests"]))) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 599ac5c3..2d1e847a 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,16 +2,15 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Coroutine +from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never -import nanoid # type: ignore +import nanoid import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span from websockets.exceptions import ConnectionClosed -from replit_river.common_session import add_msg_to_stream from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, @@ -25,7 +24,7 @@ parse_transport_msg, ) from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, InvalidMessageException, OutOfOrderMessageException, ) @@ -34,7 +33,6 @@ from .rpc import ( ACK_BIT, - STREAM_CLOSED_BIT, STREAM_OPEN_BIT, ErrorType, InitType, @@ -45,6 +43,9 @@ logger = logging.getLogger(__name__) +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 + + class ClientSession(Session): def __init__( self, @@ -81,6 +82,13 @@ async def 
do_close_websocket() -> None: self._setup_heartbeats_task(do_close_websocket) + async def replace_with_new_websocket( + self, new_ws: websockets.WebSocketCommonProtocol + ) -> None: + await super().replace_with_new_websocket(new_ws) + # serve() terminates itself when the ws dies, so we need to start it again + await self.start_serve_responses() + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) @@ -120,15 +128,25 @@ async def _handle_messages_from_ws(self) -> None: ws_wrapper = self._ws_wrapper async for message in ws_wrapper.ws: try: - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): # We should not process messages if the websocket is closed. break - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - await self._seq_manager.check_seq_and_update(msg) + match self._seq_manager.check_seq_and_update(msg): + case IgnoreMessage(): + continue + case None: + pass + case other: + assert_never(other) + await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) @@ -136,15 +154,28 @@ async def _handle_messages_from_ws(self) -> None: if msg.controlFlags & ACK_BIT != 0: continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) + stream = self._streams.get(msg.streamId, None) if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) - await add_msg_to_stream(msg, stream) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no 
longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + else: raise InvalidMessageException( "Client should not receive stream open bit" @@ -153,11 +184,7 @@ async def _handle_messages_from_ws(self) -> None: if msg.controlFlags & STREAM_CLOSED_BIT != 0: if stream: stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue + del self._streams[msg.streamId] except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 56158fcf..04215811 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -19,7 +19,6 @@ RiverException, ) from replit_river.messages import ( - PROTOCOL_VERSION, FailedSendingMessageException, WebsocketClosedException, parse_transport_msg, @@ -34,7 +33,6 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, ) from replit_river.session import Session @@ -44,6 +42,8 @@ UriAndMetadata, ) +PROTOCOL_VERSION = "v1.1" + logger = logging.getLogger(__name__) @@ -116,6 +116,8 @@ async def get_or_create_session(self) -> ClientSession: return existing_session else: logger.info("Closing stale session %s", existing_session.session_id) + await new_ws.close() # NB(dstewart): This wasn't there in the + # v1 transport, were we just leaking WS? 
await existing_session.close() return await self._create_new_session() @@ -293,10 +295,11 @@ async def _get_handshake_response_msg( "Handshake failed, conn closed while waiting for response", ) from e try: - return parse_transport_msg(data, self._transport_options) - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue + msg = parse_transport_msg(data) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue + return msg except InvalidMessageException as e: raise RiverException( ERROR_HANDSHAKE, diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 453b7a53..ea4dec85 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -802,6 +802,7 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], + method_filter: set[str] | None, ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] @@ -837,6 +838,8 @@ def __init__(self, client: river.Client[Any]): ), ] for name, procedure in schema.procedures.items(): + if method_filter and (schema_name + "." 
+ name) in method_filter: + continue module_names = [ModuleName(name)] init_type: TypeExpression | None = None if procedure.init: @@ -980,7 +983,7 @@ def __init__(self, client: river.Client[Any]): ) render_init_method = f"""\ lambda x: {render_type_expr(init_type_type_adapter_name)} - .validate_python + .validate_python(x) """ assert init_type is None or render_init_method, ( @@ -1223,6 +1226,7 @@ def generate_river_client_module( client_name: str, schema_root: RiverSchema, typed_dict_inputs: bool, + method_filter: set[str] | None, ) -> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1247,10 +1251,15 @@ def generate_river_client_module( ) for schema_name, schema in schema_root.services.items(): module_name, class_name, emitted_files = generate_individual_service( - schema_name, schema, input_base_class + schema_name, + schema, + input_base_class, + method_filter, ) - files.update(emitted_files) - modules.append((module_name, class_name)) + if emitted_files: + # Short-cut if we didn't actually emit anything + files.update(emitted_files) + modules.append((module_name, class_name)) main_contents = generate_common_client( client_name, handshake_type, handshake_chunks, modules @@ -1266,12 +1275,16 @@ def schema_to_river_client_codegen( client_name: str, typed_dict_inputs: bool, file_opener: Callable[[Path], TextIO], + method_filter: set[str] | None, ) -> None: """Generates the lines of a River module.""" with read_schema() as f: schemas = RiverSchemaFile(json.load(f)) for subpath, contents in generate_river_client_module( - client_name, schemas.root, typed_dict_inputs + client_name, + schemas.root, + typed_dict_inputs, + method_filter, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index c9ab6384..16d94a5e 100644 --- a/src/replit_river/codegen/run.py +++ 
b/src/replit_river/codegen/run.py @@ -1,5 +1,6 @@ import argparse import os.path +import pathlib from pathlib import Path from typing import TextIO @@ -38,6 +39,12 @@ def main() -> None: action="store_true", default=False, ) + client.add_argument( + "--method-filter", + help="Only generate a subset of the specified methods", + action="store", + type=pathlib.Path, + ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -56,12 +63,18 @@ def main() -> None: def file_opener(path: Path) -> TextIO: return open(path, "w") + method_filter: set[str] | None = None + if args.method_filter: + with open(args.method_filter) as handle: + method_filter = set(x.strip() for x in handle.readlines()) + schema_to_river_client_codegen( lambda: open(schema_path), target_path, args.client_name, args.typed_dict_inputs, file_opener, + method_filter=method_filter, ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 7193733d..2325492e 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -1,15 +1,9 @@ -import asyncio import enum import logging -from typing import Any, Awaitable, Callable, Protocol +from typing import Any, Protocol -from aiochannel import Channel, ChannelClosed from opentelemetry.trace import Span -from replit_river.messages import FailedSendingMessageException -from replit_river.rpc import ACK_BIT, STREAM_CLOSED_BIT, TransportMessage -from replit_river.seq_manager import InvalidMessageException - logger = logging.getLogger(__name__) @@ -29,107 +23,18 @@ async def __call__( class SessionState(enum.Enum): """The state a session can be in. - Can only transition from ACTIVE to CLOSING to CLOSED. 
+ Valid transitions: + - NO_CONNECTION -> {ACTIVE} + - ACTIVE -> {NO_CONNECTION, CLOSING} + - CLOSING -> {CLOSED} + - CLOSED -> {} """ - ACTIVE = 0 - CLOSING = 1 - CLOSED = 2 - - -async def setup_heartbeat( - session_id: str, - heartbeat_ms: float, - heartbeats_until_dead: int, - get_state: Callable[[], SessionState], - get_closing_grace_period: Callable[[], float | None], - close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, - increment_and_get_heartbeat_misses: Callable[[], int], -) -> None: - logger.debug("Start heartbeat") - while True: - await asyncio.sleep(heartbeat_ms / 1000) - state = get_state() - if state != SessionState.ACTIVE: - logger.debug( - "Session is closed, no need to send heartbeat, state : " - "%r close_session_after_this: %r", - {state}, - {get_closing_grace_period()}, - ) - # session is closing / closed, no need to send heartbeat anymore - return - try: - await send_message( - stream_id="heartbeat", - # TODO: make this a message class - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - payload={ - "ack": 0, - }, - control_flags=ACK_BIT, - procedure_name=None, - service_name=None, - span=None, - ) - - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: - if get_closing_grace_period() is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - session_id, - ) - await close_websocket() - continue - except FailedSendingMessageException: - # this is expected during websocket closed period - continue - - -async def check_to_close_session( - transport_id: str, - close_session_check_interval_ms: float, - get_state: Callable[[], SessionState], - get_current_time: Callable[[], Awaitable[float]], - get_close_session_after_time_secs: Callable[[], float | None], - do_close: Callable[[], Awaitable[None]], -) -> None: - while True: - await 
asyncio.sleep(close_session_check_interval_ms / 1000) - if get_state() != SessionState.ACTIVE: - # already closing - return - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. - current_time = await get_current_time() - close_session_after_time_secs = get_close_session_after_time_secs() - if not close_session_after_time_secs: - continue - if current_time > close_session_after_time_secs: - logger.info("Grace period ended for %s, closing session", transport_id) - await do_close() - return + NO_CONNECTION = 0 + ACTIVE = 1 + CLOSING = 2 + CLOSED = 3 -async def add_msg_to_stream( - msg: TransportMessage, - stream: Channel, -) -> None: - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - return - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. 
- pass - except RuntimeError as e: - raise InvalidMessageException(e) from e +ConnectingStates = set([SessionState.NO_CONNECTION]) +TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 8bcf023c..6e1fdad7 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -17,46 +17,46 @@ class MessageBuffer: def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] - self._lock = asyncio.Lock() - self._space_available_cond = asyncio.Condition(lock=self._lock) + self._space_available_cond = asyncio.Condition() self._closed = False - async def empty(self) -> bool: - """Check if the buffer is empty""" - async with self._lock: - return len(self.buffer) == 0 + async def has_capacity(self) -> None: + async with self._space_available_cond: + await self._space_available_cond.wait_for( + lambda: len(self.buffer) < self.max_size or self._closed + ) - async def put(self, message: TransportMessage) -> None: + def put(self, message: TransportMessage) -> None: """Add a message to the buffer. Blocks until there is space in the buffer. Raises: MessageBufferClosedError: if the buffer is closed. 
""" - async with self._space_available_cond: - await self._space_available_cond.wait_for( - lambda: len(self.buffer) < self.max_size or self._closed - ) - if self._closed: - raise MessageBufferClosedError("message buffer is closed") - self.buffer.append(message) + if self._closed: + raise MessageBufferClosedError("message buffer is closed") + self.buffer.append(message) + + def get_next_sent_seq(self) -> int | None: + if self.buffer: + return self.buffer[0].seq + return None - async def peek(self) -> TransportMessage | None: + def peek(self) -> TransportMessage | None: """Peek the first message in the buffer, returns None if the buffer is empty.""" - async with self._lock: - if len(self.buffer) == 0: - return None - return self.buffer[0] + if len(self.buffer) == 0: + return None + return self.buffer[0] async def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" - async with self._lock: - self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + async with self._space_available_cond: self._space_available_cond.notify_all() async def close(self) -> None: """ Closes the message buffer and rejects any pending put operations. 
""" - async with self._lock: - self._closed = True + self._closed = True + async with self._space_available_cond: self._space_available_cond.notify_all() diff --git a/src/replit_river/messages.py b/src/replit_river/messages.py index fc8e608a..9cdf324a 100644 --- a/src/replit_river/messages.py +++ b/src/replit_river/messages.py @@ -13,10 +13,8 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, ) -from replit_river.transport_options import TransportOptions logger = logging.getLogger(__name__) @@ -29,9 +27,6 @@ class FailedSendingMessageException(Exception): pass -PROTOCOL_VERSION = "v1.1" - - async def send_transport_message( msg: TransportMessage, ws: WebSocketCommonProtocol, @@ -62,13 +57,9 @@ def formatted_bytes(message: bytes) -> str: return " ".join(f"{b:02x}" for b in message) -def parse_transport_msg( - message: str | bytes, transport_options: TransportOptions -) -> TransportMessage: +def parse_transport_msg(message: str | bytes) -> TransportMessage | str: if isinstance(message, str): - raise IgnoreMessageException( - f"ignored a message beacuse it was a text frame: {message}" - ) + return message try: # :param int timestamp: # Control how timestamp type is unpacked: diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 0d1bd4d1..3459bf7d 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -2,6 +2,7 @@ import logging from collections.abc import AsyncIterable, AsyncIterator from typing import ( + Annotated, Any, Awaitable, Callable, @@ -46,9 +47,8 @@ GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] -ACK_BIT = 0x0001 -STREAM_OPEN_BIT = 0x0002 -STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 +ACK_BIT = 0b00001 +STREAM_OPEN_BIT = 0b00010 # these codes are retriable # if the server sends a response with one of these codes, @@ -92,7 +92,7 @@ class PropagationContext(BaseModel): class 
TransportMessage(BaseModel): id: str # from_ is used instead of from because from is a reserved keyword in Python - from_: str = Field(..., alias="from") + from_: Annotated[str, Field(alias="from")] to: str seq: int ack: int diff --git a/src/replit_river/seq_manager.py b/src/replit_river/seq_manager.py index fa75b448..8a2f6798 100644 --- a/src/replit_river/seq_manager.py +++ b/src/replit_river/seq_manager.py @@ -1,17 +1,11 @@ -import asyncio import logging +from dataclasses import dataclass from replit_river.rpc import TransportMessage logger = logging.getLogger(__name__) -class IgnoreMessageException(Exception): - """Exception to ignore a transport message, but good to continue.""" - - pass - - class InvalidMessageException(Exception): """Error processing a transport message, should raise a exception.""" @@ -34,62 +28,51 @@ class SessionStateMismatchException(Exception): pass +@dataclass +class IgnoreMessage: + pass + + class SeqManager: """Manages the sequence number and ack number for a connection.""" def __init__( self, ) -> None: - self._seq_lock = asyncio.Lock() self.seq = 0 - self._ack_lock = asyncio.Lock() self.ack = 0 self.receiver_ack = 0 - async def get_seq_and_increment(self) -> int: + def get_seq_and_increment(self) -> int: """Get the current sequence number and increment it. This removes one lock acquire than get_seq and increment_seq separately. 
""" - async with self._seq_lock: - current_value = self.seq - self.seq += 1 - return current_value - - async def increment_seq(self) -> int: - async with self._seq_lock: - self.seq += 1 - return self.seq - - async def get_seq(self) -> int: - async with self._seq_lock: - return self.seq - - async def get_ack(self) -> int: - async with self._ack_lock: - return self.ack - - async def check_seq_and_update(self, msg: TransportMessage) -> None: - async with self._ack_lock: - if msg.seq != self.ack: - if msg.seq < self.ack: - raise IgnoreMessageException( - f"{msg.from_} received duplicate msg, got {msg.seq}" - f" expected {self.ack}" - ) - else: - logger.warn( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - self.receiver_ack = msg.ack - await self._set_ack(msg.seq + 1) - - async def _set_ack(self, new_ack: int) -> int: - async with self._ack_lock: - self.ack = new_ack - return self.ack + current_value = self.seq + self.seq += 1 + return current_value + + def increment_seq(self) -> int: + self.seq += 1 + return self.seq + + def get_seq(self) -> int: + return self.seq + + def get_ack(self) -> int: + return self.ack + + def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: + if msg.seq != self.ack: + if msg.seq < self.ack: + return IgnoreMessage() + else: + logger.warning( + f"Out of order message received got {msg.seq} expected {self.ack}" + ) + + raise OutOfOrderMessageException( + f"Out of order message received got {msg.seq} expected {self.ack}" + ) + self.receiver_ack = msg.ack + self.ack = msg.seq + 1 + return None diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 868ff0fb..3a931274 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -1,19 +1,18 @@ import asyncio import logging -from typing import Any, Callable, 
Coroutine +from typing import Any, Callable, Coroutine, assert_never import websockets from aiochannel import Channel, ChannelClosed from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed -from replit_river.common_session import add_msg_to_stream from replit_river.messages import ( FailedSendingMessageException, parse_transport_msg, ) from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, InvalidMessageException, OutOfOrderMessageException, ) @@ -22,18 +21,18 @@ from .rpc import ( ACK_BIT, - STREAM_CLOSED_BIT, STREAM_OPEN_BIT, GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, ) -logger = logging.getLogger(__name__) +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 logger = logging.getLogger(__name__) + trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() @@ -119,15 +118,24 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: ws_wrapper = self._ws_wrapper async for message in ws_wrapper.ws: try: - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): # We should not process messages if the websocket is closed. 
break - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - await self._seq_manager.check_seq_and_update(msg) + match self._seq_manager.check_seq_and_update(msg): + case IgnoreMessage(): + continue + case None: + pass + case other: + assert_never(other) await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) @@ -135,30 +143,40 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if msg.controlFlags & ACK_BIT != 0: continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) + stream = self._streams.get(msg.streamId) if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) - await add_msg_to_stream(msg, stream) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. 
+ pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + else: _stream = await self._open_stream_and_call_handler(msg, tg) + if isinstance(_stream, IgnoreMessage): + continue if not stream: - async with self._stream_lock: - self._streams[msg.streamId] = _stream + self._streams[msg.streamId] = _stream stream = _stream if msg.controlFlags & STREAM_CLOSED_BIT != 0: if stream: stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue + del self._streams[msg.streamId] except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() @@ -173,18 +191,23 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: async def _open_stream_and_call_handler( self, msg: TransportMessage, - tg: asyncio.TaskGroup | None, - ) -> Channel: + tg: asyncio.TaskGroup, + ) -> Channel | IgnoreMessage: if not msg.serviceName or not msg.procedureName: - raise IgnoreMessageException( - f"Service name or procedure name is missing in the message {msg}" + logger.warning( + "Service name or procedure name is missing in the message %r", + msg, ) + return IgnoreMessage() key = (msg.serviceName, msg.procedureName) handler = self._handlers.get(key, None) if not handler: - raise IgnoreMessageException( - f"No handler for {key} handlers : {self._handlers.keys()}" + logger.warning( + "No handler for %r handlers: %r", + key, + self._handlers.keys(), ) + return IgnoreMessage() method_type, handler_func = handler is_streaming_output = method_type in ( "subscription-stream", # subscription diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 88e848a7..7b15e36e 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -2,7 +2,7 @@ import logging from typing import Any -import nanoid # type: ignore # type: ignore 
+import nanoid from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, @@ -11,7 +11,6 @@ from websockets.exceptions import ConnectionClosed from replit_river.messages import ( - PROTOCOL_VERSION, FailedSendingMessageException, WebsocketClosedException, parse_transport_msg, @@ -25,7 +24,6 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, SessionStateMismatchException, ) @@ -33,6 +31,8 @@ from replit_river.session import Session from replit_river.transport_options import TransportOptions +PROTOCOL_VERSION = "v1.1" + logger = logging.getLogger(__name__) @@ -72,13 +72,13 @@ async def handshake_to_get_session( ) -> ServerSession: async for message in websocket: try: - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + continue ( handshake_request, handshake_response, ) = await self._establish_handshake(msg, websocket) - except IgnoreMessageException: - continue except InvalidMessageException as e: error_msg = "Got invalid transport message, closing connection" raise InvalidMessageException(error_msg) from e diff --git a/src/replit_river/session.py b/src/replit_river/session.py index d908bdda..ac01ffba 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,17 +1,17 @@ import asyncio import logging -from typing import Any, Awaitable, Callable, Coroutine +from typing import Any, Awaitable, Callable, Coroutine, TypeAlias -import nanoid # type: ignore +import nanoid import websockets from aiochannel import Channel from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from replit_river.common_session import ( + SendMessage, SessionState, - check_to_close_session, - setup_heartbeat, + TerminalStates, ) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages 
import ( @@ -27,6 +27,7 @@ from replit_river.websocket_wrapper import WebsocketWrapper from .rpc import ( + ACK_BIT, TransportMessage, TransportMessageTracingSetter, ) @@ -36,10 +37,41 @@ trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() +CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +RetryConnectionCallback: TypeAlias = Callable[ + [], + Coroutine[Any, Any, Any], +] + class Session: """Common functionality shared between client_session and server_session""" + _transport_id: str + _to_id: str + session_id: str + _transport_options: TransportOptions + + # session state + _state: SessionState + _state_lock: asyncio.Lock + _close_session_callback: CloseSessionCallback + _close_session_after_time_secs: float | None + + # ws state + _ws_lock: asyncio.Lock + _ws_wrapper: WebsocketWrapper + _heartbeat_misses: int + _retry_connection_callback: RetryConnectionCallback | None + + # stream for tasks + _streams: dict[str, Channel[Any]] + + # book keeping + _seq_manager: SeqManager + _buffer: MessageBuffer + _task_manager: BackgroundTaskManager + def __init__( self, transport_id: str, @@ -47,21 +79,15 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], - retry_connection_callback: ( - Callable[ - [], - Coroutine[Any, Any, Any], - ] - | None - ) = None, + close_session_callback: CloseSessionCallback, + retry_connection_callback: RetryConnectionCallback | None = None, ) -> None: self._transport_id = transport_id self._to_id = to_id self.session_id = session_id self._transport_options = transport_options - # session state, only modified during closing + # session state self._state = SessionState.ACTIVE self._state_lock = asyncio.Lock() self._close_session_callback = close_session_callback @@ -74,12 +100,10 @@ def __init__( self._retry_connection_callback = 
retry_connection_callback # stream for tasks - self._stream_lock = asyncio.Lock() self._streams: dict[str, Channel[Any]] = {} # book keeping self._seq_manager = SeqManager() - self._msg_lock = asyncio.Lock() self._buffer = MessageBuffer(self._transport_options.buffer_size) self._task_manager = BackgroundTaskManager() @@ -120,7 +144,7 @@ async def is_session_open(self) -> bool: async def is_websocket_open(self) -> bool: async with self._ws_lock: - return await self._ws_wrapper.is_open() + return self._ws_wrapper.is_open() async def _begin_close_session_countdown(self) -> None: """Begin the countdown to close session, this should be called when @@ -157,9 +181,10 @@ async def replace_with_new_websocket( buffered_messages = list(self._buffer.buffer) for msg in buffered_messages: try: - await self._send_transport_message( + await send_transport_message( msg, new_ws, + self._begin_close_session_countdown, ) except WebsocketClosedException: logger.info( @@ -177,34 +202,19 @@ def _reset_session_close_countdown(self) -> None: self._heartbeat_misses = 0 self._close_session_after_time_secs = None - async def _send_transport_message( - self, - msg: TransportMessage, - websocket: websockets.WebSocketCommonProtocol, - ) -> None: - try: - await send_transport_message( - msg, websocket, self._begin_close_session_countdown - ) - except WebsocketClosedException as e: - raise e - except FailedSendingMessageException as e: - raise e - async def get_next_expected_seq(self) -> int: """Get the next expected sequence number from the server.""" - return await self._seq_manager.get_ack() + return self._seq_manager.get_ack() async def get_next_sent_seq(self) -> int: """Get the next sequence number that the client will send.""" - nextMessage = await self._buffer.peek() - if nextMessage: - return nextMessage.seq - return await self._seq_manager.get_seq() + # 0 is a valid seq but falsy, so only fall back when the buffer returns None + next_sent_seq = self._buffer.get_next_sent_seq() + return self._seq_manager.get_seq() if next_sent_seq is None else next_sent_seq async def get_next_expected_ack(self) -> int: """Get the next
expected ack that the client expects.""" - return await self._seq_manager.get_seq() + return self._seq_manager.get_seq() async def send_message( self, @@ -219,13 +227,15 @@ async def send_message( # if the session is not active, we should not do anything if self._state != SessionState.ACTIVE: return + await self._buffer.has_capacity() + # Start of critical section. No await between here and buffer.put()! msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), from_=self._transport_id, # type: ignore to=self._to_id, - seq=await self._seq_manager.get_seq_and_increment(), - ack=await self._seq_manager.get_ack(), + seq=self._seq_manager.get_seq_and_increment(), + ack=self._seq_manager.get_ack(), controlFlags=control_flags, payload=payload, serviceName=service_name, @@ -235,23 +245,19 @@ async def send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) try: - # We need this lock to ensure the buffer order and message sending order - # are the same. - async with self._msg_lock: - try: - await self._buffer.put(msg) - except MessageBufferClosedError: - # The session is closed and is no longer accepting new messages. + try: + self._buffer.put(msg) + except MessageBufferClosedError: + # The session is closed and is no longer accepting new messages. + return + async with self._ws_lock: + if not self._ws_wrapper.is_open(): + # If the websocket is closed, we should not send the message + # and wait for the retry from the buffer. return - async with self._ws_lock: - if not await self._ws_wrapper.is_open(): - # If the websocket is closed, we should not send the message - # and wait for the retry from the buffer. 
- return - await self._send_transport_message( - msg, - self._ws_wrapper.ws, - ) + await send_transport_message( + msg, self._ws_wrapper.ws, self._begin_close_session_countdown + ) except WebsocketClosedException as e: logger.debug( "Connection closed while sending message %r, waiting for " @@ -270,7 +276,7 @@ async def close_websocket( """Mark the websocket as closed, close the websocket, and retry if needed.""" async with self._ws_lock: # Already closed. - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): return await ws_wrapper.close() if should_retry and self._retry_connection_callback: @@ -302,7 +308,86 @@ async def close(self) -> None: # throw exception correctly. for stream in self._streams.values(): stream.close() - async with self._stream_lock: - self._streams.clear() + self._streams.clear() self._state = SessionState.CLOSED + + +async def check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], Awaitable[None]], +) -> None: + while True: + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() in TerminalStates: + # already closing + return + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. 
+        current_time = await get_current_time()
+        close_session_after_time_secs = get_close_session_after_time_secs()
+        if not close_session_after_time_secs:
+            continue
+        if current_time > close_session_after_time_secs:
+            logger.info("Grace period ended for %s, closing session", transport_id)
+            await do_close()
+            return
+
+
+async def setup_heartbeat(
+    session_id: str,
+    heartbeat_ms: float,
+    heartbeats_until_dead: int,
+    get_state: Callable[[], SessionState],
+    get_closing_grace_period: Callable[[], float | None],
+    close_websocket: Callable[[], Awaitable[None]],
+    send_message: SendMessage,
+    increment_and_get_heartbeat_misses: Callable[[], int],
+) -> None:
+    """Send an ACK heartbeat every ``heartbeat_ms`` ms until the session ends.
+
+    Closes the websocket once misses exceed ``heartbeats_until_dead``.
+    """
+    while True:
+        await asyncio.sleep(heartbeat_ms / 1000)
+        state = get_state()
+        # terminal states are also != ACTIVE, so they must be checked first to exit
+        if state in TerminalStates:
+            logger.debug(
+                "Session is closed, no need to send heartbeat, state : "
+                "%r close_session_after_this: %r",
+                state,
+                get_closing_grace_period(),
+            )
+            return
+        if state != SessionState.ACTIVE:
+            logger.debug("Websocket is not connected, not sending heartbeat")
+            continue
+        try:
+            await send_message(
+                stream_id="heartbeat",
+                # TODO: make this a message class
+                # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
+                payload={"ack": 0},
+                control_flags=ACK_BIT,
+                procedure_name=None,
+                service_name=None,
+                span=None,
+            )
+            if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
+                if get_closing_grace_period() is not None:
+                    # already in grace period, no need to set again
+                    continue
+                logger.info(
+                    "%r closing websocket because of heartbeat misses",
+                    session_id,
+                )
+                await close_websocket()
+        except FailedSendingMessageException:
+            # this is expected during websocket closed period
+            continue
diff --git a/src/replit_river/websocket_wrapper.py b/src/replit_river/websocket_wrapper.py
index 528fad7e..a1eb6a80 100644
---
a/src/replit_river/websocket_wrapper.py +++ b/src/replit_river/websocket_wrapper.py @@ -18,18 +18,15 @@ class WebsocketWrapper: def __init__(self, ws: WebSocketCommonProtocol) -> None: self.ws = ws self.ws_state = WsState.OPEN - self.ws_lock = asyncio.Lock() self.id = ws.id - async def is_open(self) -> bool: - async with self.ws_lock: - return self.ws_state == WsState.OPEN + def is_open(self) -> bool: + return self.ws_state == WsState.OPEN async def close(self) -> None: - async with self.ws_lock: - if self.ws_state == WsState.OPEN: - self.ws_state = WsState.CLOSING - task = asyncio.create_task(self.ws.close()) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - self.ws_state = WsState.CLOSED + if self.ws_state == WsState.OPEN: + self.ws_state = WsState.CLOSING + task = asyncio.create_task(self.ws.close()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + self.ws_state = WsState.CLOSED diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index 1a343fc8..ef74a1fb 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -34,6 +34,7 @@ def file_opener(path: Path) -> TextIO: client_name=client_name, file_opener=file_opener, typed_dict_inputs=True, + method_filter=None, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index 9c2a5d8e..8ab82095 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -1,5 +1,6 @@ import asyncio import importlib +import os import shutil from datetime import timedelta from pathlib import Path @@ -19,20 +20,26 @@ @pytest.fixture(scope="session", autouse=True) def generate_rpc_client() -> None: - import tests.codegen.rpc.generated - - shutil.rmtree("tests/codegen/rpc/generated") + shutil.rmtree("tests/codegen/rpc/generated", ignore_errors=True) + 
os.makedirs("tests/codegen/rpc/generated") def file_opener(path: Path) -> TextIO: return open(path, "w") schema_to_river_client_codegen( - lambda: open("tests/codegen/rpc/schema.json"), - "tests/codegen/rpc/generated", - "RpcClient", - True, - file_opener, + read_schema=lambda: open("tests/codegen/rpc/schema.json"), + target_path="tests/codegen/rpc/generated", + client_name="RpcClient", + typed_dict_inputs=True, + file_opener=file_opener, + method_filter=None, ) + + +@pytest.fixture(scope="session", autouse=True) +def reload_rpc_import(generate_rpc_client: None) -> None: + import tests.codegen.rpc.generated + importlib.reload(tests.codegen.rpc.generated) diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index 7d330375..02a21ccb 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -35,7 +35,8 @@ async def test_message_buffer_backpressure() -> None: async def put_messages() -> None: for i in range(0, iterations): - await buffer.put(mock_transport_message(seq=i)) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=i)) await sync_events.put(None) background_puts = asyncio.create_task(put_messages()) @@ -55,10 +56,17 @@ async def test_message_buffer_close() -> None: is closed while the put operation is waiting for space in the buffer. 
""" buffer = MessageBuffer(max_num_messages=1) - await buffer.put(mock_transport_message(seq=1)) - background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1))) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=1)) + + async def bg_put(msg: TransportMessage) -> None: + await buffer.has_capacity() + buffer.put(msg) + + background_put = asyncio.create_task(bg_put(mock_transport_message(seq=1))) await buffer.close() with pytest.raises(MessageBufferClosedError): await background_put with pytest.raises(MessageBufferClosedError): - await buffer.put(mock_transport_message(seq=1)) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=1)) diff --git a/tests/test_seq_manager.py b/tests/test_seq_manager.py index cf53a3d7..fba7f58a 100644 --- a/tests/test_seq_manager.py +++ b/tests/test_seq_manager.py @@ -1,9 +1,7 @@ -import asyncio - import pytest from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, OutOfOrderMessageException, SeqManager, ) @@ -14,17 +12,17 @@ @pytest.mark.asyncio async def test_initial_sequence_and_ack_numbers(no_logging_error: NoErrors) -> None: manager = SeqManager() - assert await manager.get_seq() == 0, "Initial sequence number should be 0" - assert await manager.get_ack() == 0, "Initial acknowledgment number should be 0" + assert manager.get_seq() == 0, "Initial sequence number should be 0" + assert manager.get_ack() == 0, "Initial acknowledgment number should be 0" no_logging_error() @pytest.mark.asyncio async def test_sequence_number_increment(no_logging_error: NoErrors) -> None: manager = SeqManager() - initial_seq = await manager.get_seq_and_increment() + initial_seq = manager.get_seq_and_increment() assert initial_seq == 0, "Sequence number should start at 0" - new_seq = await manager.get_seq() + new_seq = manager.get_seq() assert new_seq == 1, "Sequence number should increment to 1" no_logging_error() @@ -33,42 +31,40 @@ async def 
test_sequence_number_increment(no_logging_error: NoErrors) -> None: async def test_message_reception(no_logging_error: NoErrors) -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") - await manager.check_seq_and_update( + manager.check_seq_and_update( msg ) # No error should be raised for the correct sequence - assert await manager.get_ack() == 1, "Acknowledgment should be set to 1" + assert manager.get_ack() == 1, "Acknowledgment should be set to 1" # We assert no errors before we send out-of-order messages no_logging_error() # Test duplicate message - with pytest.raises(IgnoreMessageException): - await manager.check_seq_and_update(msg) + assert isinstance(manager.check_seq_and_update(msg), IgnoreMessage) # Test out of order message msg.seq = 2 with pytest.raises(OutOfOrderMessageException): - await manager.check_seq_and_update(msg) + manager.check_seq_and_update(msg) @pytest.mark.asyncio async def test_acknowledgment_setting(no_logging_error: NoErrors) -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") - await manager.check_seq_and_update(msg) - assert await manager.get_ack() == 1, "Acknowledgment number should be updated" + manager.check_seq_and_update(msg) + assert manager.get_ack() == 1, "Acknowledgment number should be updated" no_logging_error() @pytest.mark.asyncio async def test_concurrent_access_to_sequence(no_logging_error: NoErrors) -> None: manager = SeqManager() - tasks = [manager.get_seq_and_increment() for _ in range(10)] - results = await asyncio.gather(*tasks) + results = [manager.get_seq_and_increment() for _ in range(10)] assert len(set(results)) == 10, ( "Each increment call should return a unique sequence number" ) - assert await manager.get_seq() == 10, ( + assert manager.get_seq() == 10, ( "Final sequence number should be 10 after 10 increments" ) no_logging_error()