From 1ae427caff643d0efc517aa8f649fb0767aab610 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:22:58 -0700 Subject: [PATCH 01/29] lint tests as well --- scripts/lint/src/lint/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/lint/src/lint/__init__.py b/scripts/lint/src/lint/__init__.py index f5b5c3d3..9d62a193 100644 --- a/scripts/lint/src/lint/__init__.py +++ b/scripts/lint/src/lint/__init__.py @@ -11,7 +11,8 @@ def raise_err(code: int) -> None: def main() -> None: fix = ["--fix"] if "--fix" in sys.argv else [] + watch = ["--watch"] if "--watch" in sys.argv else [] raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix))) raise_err(os.system("ruff format src scripts tests")) - raise_err(os.system("mypy src")) - raise_err(os.system("pyright src")) + raise_err(os.system("mypy src tests")) + raise_err(os.system(" ".join(["pyright"] + watch + ["src", "tests"]))) From 76f9d3a2d180cea25e6eacd1612849a1eabf6058 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 22:31:51 -0700 Subject: [PATCH 02/29] Permit method filtering based on supplied file --- src/replit_river/codegen/client.py | 21 +++++++++++++++---- src/replit_river/codegen/run.py | 13 ++++++++++++ .../snapshot/codegen_snapshot_fixtures.py | 1 + tests/codegen/test_rpc.py | 11 +++++----- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 453b7a53..9aa18bf9 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -802,6 +802,7 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], + method_filter: set[str] | None, ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] @@ -837,6 +838,8 @@ def __init__(self, client: river.Client[Any]): ), ] for name, procedure in schema.procedures.items(): + if method_filter and (schema_name + "." + name) in method_filter: + continue module_names = [ModuleName(name)] init_type: TypeExpression | None = None if procedure.init: @@ -1223,6 +1226,7 @@ def generate_river_client_module( client_name: str, schema_root: RiverSchema, typed_dict_inputs: bool, + method_filter: set[str] | None, ) -> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1247,10 +1251,15 @@ def generate_river_client_module( ) for schema_name, schema in schema_root.services.items(): module_name, class_name, emitted_files = generate_individual_service( - schema_name, schema, input_base_class + schema_name, + schema, + input_base_class, + method_filter, ) - files.update(emitted_files) - modules.append((module_name, class_name)) + if emitted_files: + # Short-cut if we didn't actually emit anything + files.update(emitted_files) + modules.append((module_name, class_name)) main_contents = generate_common_client( client_name, handshake_type, handshake_chunks, modules @@ -1266,12 +1275,16 @@ def schema_to_river_client_codegen( client_name: str, typed_dict_inputs: bool, file_opener: Callable[[Path], TextIO], + method_filter: set[str] | None, ) -> None: """Generates the lines of a River module.""" with read_schema() as f: schemas = RiverSchemaFile(json.load(f)) for subpath, contents in generate_river_client_module( - client_name, schemas.root, typed_dict_inputs + client_name, + schemas.root, + typed_dict_inputs, + method_filter, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index c9ab6384..2eb9020b 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -1,5 +1,6 @@ import argparse import os.path +import pathlib from pathlib import Path from typing import TextIO @@ -38,9 +39,20 @@ def main() -> None: action="store_true", default=False, ) + client.add_argument( + "--method-filter", + help="Only generate a subset of the specified methods", + action="store", + type=pathlib.Path, + ) client.add_argument("schema", help="schema file") args = parser.parse_args() + method_filter: set[str] | None = None + if args.method_filter: + with open(args.method_filter) as handle: + method_filter = set(x.strip() for x in handle.readlines()) + if args.command == "server": proto_path = os.path.abspath(args.proto) target_directory = os.path.abspath(args.output) @@ -62,6 +74,7 @@ def file_opener(path: Path) -> TextIO: args.client_name, args.typed_dict_inputs, file_opener, + method_filter=method_filter, ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index 1a343fc8..ef74a1fb 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -34,6 +34,7 @@ def file_opener(path: Path) -> TextIO: client_name=client_name, file_opener=file_opener, typed_dict_inputs=True, + method_filter=None, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index 9c2a5d8e..c5483432 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -27,11 +27,12 @@ def file_opener(path: Path) -> TextIO: return open(path, "w") schema_to_river_client_codegen( - lambda: open("tests/codegen/rpc/schema.json"), - "tests/codegen/rpc/generated", - "RpcClient", - True, - file_opener, + read_schema=lambda: open("tests/codegen/rpc/schema.json"), + target_path="tests/codegen/rpc/generated", + client_name="RpcClient", + typed_dict_inputs=True, + file_opener=file_opener, + method_filter=None, ) importlib.reload(tests.codegen.rpc.generated) From bbd204993da7bd306778b30bedeadd0c2810e58d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:27:18 -0700 Subject: [PATCH 03/29] Missing invocation of validate_python --- src/replit_river/codegen/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 9aa18bf9..ea4dec85 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -983,7 +983,7 @@ def __init__(self, client: river.Client[Any]): ) render_init_method = f"""\ lambda x: {render_type_expr(init_type_type_adapter_name)} - .validate_python + .validate_python(x) """ assert init_type is None or render_init_method, ( From 64b1b89d6e63e7904022f7bf368399aee2318857 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 23:02:42 -0700 Subject: [PATCH 04/29] inline add_msg_to_stream --- src/replit_river/client_session.py | 18 ++++++++++++++++-- src/replit_river/common_session.py | 24 +----------------------- src/replit_river/server_session.py | 18 ++++++++++++++++-- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 599ac5c3..656af6c2 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -11,7 +11,6 @@ from opentelemetry.trace import Span from websockets.exceptions import ConnectionClosed -from replit_river.common_session import add_msg_to_stream from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, @@ -144,7 +143,22 @@ async def _handle_messages_from_ws(self) -> None: raise IgnoreMessageException( "no stream for message, ignoring" ) - await add_msg_to_stream(msg, stream) + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + else: raise InvalidMessageException( "Client should not receive stream open bit" diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 7193733d..52271004 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -3,12 +3,10 @@ import logging from typing import Any, Awaitable, Callable, Protocol -from aiochannel import Channel, ChannelClosed from opentelemetry.trace import Span from replit_river.messages import FailedSendingMessageException -from replit_river.rpc import ACK_BIT, STREAM_CLOSED_BIT, TransportMessage -from replit_river.seq_manager import InvalidMessageException +from replit_river.rpc import ACK_BIT logger = logging.getLogger(__name__) @@ -113,23 +111,3 @@ async def check_to_close_session( logger.info("Grace period ended for %s, closing session", transport_id) await do_close() return - - -async def add_msg_to_stream( - msg: TransportMessage, - stream: Channel, -) -> None: - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - return - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. - pass - except RuntimeError as e: - raise InvalidMessageException(e) from e diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 868ff0fb..da29477e 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -7,7 +7,6 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed -from replit_river.common_session import add_msg_to_stream from replit_river.messages import ( FailedSendingMessageException, parse_transport_msg, @@ -143,7 +142,22 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: raise IgnoreMessageException( "no stream for message, ignoring" ) - await add_msg_to_stream(msg, stream) + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + else: _stream = await self._open_stream_and_call_handler(msg, tg) if not stream: From 90016eedc0919a9f058381aea98b72a9d88ca1aa Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 23:06:54 -0700 Subject: [PATCH 05/29] Moving STREAM_CLOSED_BIT into v1 session objects for clarity in preparation for v2 --- src/replit_river/client_session.py | 4 +++- src/replit_river/rpc.py | 1 - src/replit_river/server_session.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 656af6c2..04efd22a 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -33,7 +33,6 @@ from .rpc import ( ACK_BIT, - STREAM_CLOSED_BIT, STREAM_OPEN_BIT, ErrorType, InitType, @@ -44,6 +43,9 @@ logger = logging.getLogger(__name__) +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 + + class ClientSession(Session): def __init__( self, diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 0d1bd4d1..f00db903 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -48,7 +48,6 @@ ] ACK_BIT = 0x0001 STREAM_OPEN_BIT = 0x0002 -STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 # these codes are retriable # if the server sends a response with one of these codes, diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index da29477e..98d9c9a6 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -21,14 +21,13 @@ from .rpc import ( ACK_BIT, - STREAM_CLOSED_BIT, STREAM_OPEN_BIT, GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, ) -logger = logging.getLogger(__name__) +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 logger = logging.getLogger(__name__) From 7d3b6baa21c69ef581021c72b8b9cdaba9a6f346 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:38:34 -0700 Subject: [PATCH 06/29] parse_transport_msg should just return a value --- src/replit_river/client_session.py | 5 ++++- src/replit_river/client_transport.py | 10 +++++----- src/replit_river/messages.py | 10 ++-------- src/replit_river/server_session.py | 5 ++++- src/replit_river/server_transport.py | 4 +++- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 04efd22a..845cb370 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -124,7 +124,10 @@ async def _handle_messages_from_ws(self) -> None: if not await ws_wrapper.is_open(): # We should not process messages if the websocket is closed. break - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue logger.debug(f"{self._transport_id} got a message %r", msg) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 56158fcf..6aa8aa09 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -34,7 +34,6 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, ) from replit_river.session import Session @@ -293,10 +292,11 @@ async def _get_handshake_response_msg( "Handshake failed, conn closed while waiting for response", ) from e try: - return parse_transport_msg(data, self._transport_options) - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue + msg = parse_transport_msg(data) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue + return msg except InvalidMessageException as e: raise RiverException( ERROR_HANDSHAKE, diff --git a/src/replit_river/messages.py b/src/replit_river/messages.py index fc8e608a..f3aefb4c 100644 --- a/src/replit_river/messages.py +++ b/src/replit_river/messages.py @@ -13,10 +13,8 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, ) -from replit_river.transport_options import TransportOptions logger = logging.getLogger(__name__) @@ -62,13 +60,9 @@ def formatted_bytes(message: bytes) -> str: return " ".join(f"{b:02x}" for b in message) -def parse_transport_msg( - message: str | bytes, transport_options: TransportOptions -) -> TransportMessage: +def parse_transport_msg(message: str | bytes) -> TransportMessage | str: if isinstance(message, str): - raise IgnoreMessageException( - f"ignored a message beacuse it was a text frame: {message}" - ) + return message try: # :param int timestamp: # Control how timestamp type is unpacked: diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 98d9c9a6..ddf55524 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -120,7 +120,10 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if not await ws_wrapper.is_open(): # We should not process messages if the websocket is closed. break - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue logger.debug(f"{self._transport_id} got a message %r", msg) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 88e848a7..a1edc68c 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -72,7 +72,9 @@ async def handshake_to_get_session( ) -> ServerSession: async for message in websocket: try: - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) + if isinstance(msg, str): + continue ( handshake_request, handshake_response, From 4d860455df9db0b34ffd716ba4547963a276fce9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:43:24 -0700 Subject: [PATCH 07/29] Switch from IgnoreMessageException to IgnoreMessage return value --- src/replit_river/client_session.py | 21 +++++++++------- src/replit_river/seq_manager.py | 20 +++++++-------- src/replit_river/server_session.py | 37 +++++++++++++++++----------- src/replit_river/server_transport.py | 3 --- tests/test_seq_manager.py | 5 ++-- 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 845cb370..e9729e3e 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Coroutine +from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never import nanoid # type: ignore import websockets @@ -24,7 +24,7 @@ parse_transport_msg, ) from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, InvalidMessageException, OutOfOrderMessageException, ) @@ -132,7 +132,14 @@ async def _handle_messages_from_ws(self) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - await self._seq_manager.check_seq_and_update(msg) + match await self._seq_manager.check_seq_and_update(msg): + case IgnoreMessage(): + continue + case None: + pass + case other: + assert_never(other) + await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) @@ -145,9 +152,8 @@ async def _handle_messages_from_ws(self) -> None: if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) + continue + if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 and msg.payload.get("type", None) == "CLOSE" @@ -174,9 +180,6 @@ async def _handle_messages_from_ws(self) -> None: stream.close() async with self._stream_lock: del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() diff --git a/src/replit_river/seq_manager.py b/src/replit_river/seq_manager.py index fa75b448..c75d4d97 100644 --- a/src/replit_river/seq_manager.py +++ b/src/replit_river/seq_manager.py @@ -1,17 +1,12 @@ import asyncio import logging +from dataclasses import dataclass from replit_river.rpc import TransportMessage logger = logging.getLogger(__name__) -class IgnoreMessageException(Exception): - """Exception to ignore a transport message, but good to continue.""" - - pass - - class InvalidMessageException(Exception): """Error processing a transport message, should raise a exception.""" @@ -34,6 +29,11 @@ class SessionStateMismatchException(Exception): pass +@dataclass +class IgnoreMessage: + pass + + class SeqManager: """Manages the sequence number and ack number for a connection.""" @@ -68,14 +68,11 @@ async def get_ack(self) -> int: async with self._ack_lock: return self.ack - async def check_seq_and_update(self, msg: TransportMessage) -> None: + async def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: async with self._ack_lock: if msg.seq != self.ack: if msg.seq < self.ack: - raise IgnoreMessageException( - f"{msg.from_} received duplicate msg, got {msg.seq}" - f" expected {self.ack}" - ) + return IgnoreMessage() else: logger.warn( f"Out of order message received got {msg.seq} expected " @@ -88,6 +85,7 @@ async def check_seq_and_update(self, msg: TransportMessage) -> None: ) self.receiver_ack = msg.ack await self._set_ack(msg.seq + 1) + return None async def _set_ack(self, new_ack: int) -> int: async with self._ack_lock: diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index ddf55524..085230c1 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, Callable, Coroutine +from typing import Any, Callable, Coroutine, assert_never import websockets from aiochannel import Channel, ChannelClosed @@ -12,7 +12,7 @@ parse_transport_msg, ) from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, InvalidMessageException, OutOfOrderMessageException, ) @@ -128,7 +128,13 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - await self._seq_manager.check_seq_and_update(msg) + match await self._seq_manager.check_seq_and_update(msg): + case IgnoreMessage(): + continue + case None: + pass + case other: + assert_never(other) await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) @@ -141,9 +147,8 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) + continue + if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 and msg.payload.get("type", None) == "CLOSE" @@ -162,6 +167,8 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: else: _stream = await self._open_stream_and_call_handler(msg, tg) + if isinstance(_stream, IgnoreMessage): + continue if not stream: async with self._stream_lock: self._streams[msg.streamId] = _stream @@ -172,9 +179,6 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: stream.close() async with self._stream_lock: del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() @@ -190,17 +194,22 @@ async def _open_stream_and_call_handler( self, msg: TransportMessage, tg: asyncio.TaskGroup | None, - ) -> Channel: + ) -> Channel | IgnoreMessage: if not msg.serviceName or not msg.procedureName: - raise IgnoreMessageException( - f"Service name or procedure name is missing in the message {msg}" + logger.warning( + "Service name or procedure name is missing in the message %r", + msg, ) + return IgnoreMessage() key = (msg.serviceName, msg.procedureName) handler = self._handlers.get(key, None) if not handler: - raise IgnoreMessageException( - f"No handler for {key} handlers : {self._handlers.keys()}" + logger.warning( + "No handler for %r handlers: %r", + key, + self._handlers.keys(), ) + return IgnoreMessage() method_type, handler_func = handler is_streaming_output = method_type in ( "subscription-stream", # subscription diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index a1edc68c..9c4c9c5a 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -25,7 +25,6 @@ TransportMessage, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, SessionStateMismatchException, ) @@ -79,8 +78,6 @@ async def handshake_to_get_session( handshake_request, handshake_response, ) = await self._establish_handshake(msg, websocket) - except IgnoreMessageException: - continue except InvalidMessageException as e: error_msg = "Got invalid transport message, closing connection" raise InvalidMessageException(error_msg) from e diff --git a/tests/test_seq_manager.py b/tests/test_seq_manager.py index cf53a3d7..323e0958 100644 --- a/tests/test_seq_manager.py +++ b/tests/test_seq_manager.py @@ -3,7 +3,7 @@ import pytest from replit_river.seq_manager import ( - IgnoreMessageException, + IgnoreMessage, OutOfOrderMessageException, SeqManager, ) @@ -42,8 +42,7 @@ async def test_message_reception(no_logging_error: NoErrors) -> None: no_logging_error() # Test duplicate message - with pytest.raises(IgnoreMessageException): - await manager.check_seq_and_update(msg) + assert isinstance(await manager.check_seq_and_update(msg), IgnoreMessage) # Test out of order message msg.seq = 2 From 607bacfe4e6e36eb54b02d939eecfe13f0e5f5c3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:57:03 -0700 Subject: [PATCH 08/29] Missing ws.close() --- src/replit_river/client_transport.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 6aa8aa09..d02bc9e8 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -115,6 +115,8 @@ async def get_or_create_session(self) -> ClientSession: return existing_session else: logger.info("Closing stale session %s", existing_session.session_id) + await new_ws.close() # NB(dstewart): This wasn't there in the + # v1 transport, were we just leaking WS? await existing_session.close() return await self._create_new_session() From c2ecb772fe122ea1adb3c709b587a600afca5e2c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 21:56:55 -0700 Subject: [PATCH 09/29] Distribute PROTOCOL_VERSION through the different files it belongs in --- src/replit_river/client_transport.py | 3 ++- src/replit_river/messages.py | 3 --- src/replit_river/server_transport.py | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index d02bc9e8..04215811 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -19,7 +19,6 @@ RiverException, ) from replit_river.messages import ( - PROTOCOL_VERSION, FailedSendingMessageException, WebsocketClosedException, parse_transport_msg, @@ -43,6 +42,8 @@ UriAndMetadata, ) +PROTOCOL_VERSION = "v1.1" + logger = logging.getLogger(__name__) diff --git a/src/replit_river/messages.py b/src/replit_river/messages.py index f3aefb4c..9cdf324a 100644 --- a/src/replit_river/messages.py +++ b/src/replit_river/messages.py @@ -27,9 +27,6 @@ class FailedSendingMessageException(Exception): pass -PROTOCOL_VERSION = "v1.1" - - async def send_transport_message( msg: TransportMessage, ws: WebSocketCommonProtocol, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 9c4c9c5a..05a36d6d 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -11,7 +11,6 @@ from websockets.exceptions import ConnectionClosed from replit_river.messages import ( - PROTOCOL_VERSION, FailedSendingMessageException, WebsocketClosedException, parse_transport_msg, @@ -32,6 +31,8 @@ from replit_river.session import Session from replit_river.transport_options import TransportOptions +PROTOCOL_VERSION = "v1.1" + logger = logging.getLogger(__name__) From 0775f21bda5b23f1a8e948c8b5820c43a36f2960 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:00:05 -0700 Subject: [PATCH 10/29] Moving setup_heartbeat to session.py --- src/replit_river/common_session.py | 55 ------------------------------ src/replit_river/session.py | 55 +++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 56 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 52271004..1f3caf3a 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -5,9 +5,6 @@ from opentelemetry.trace import Span -from replit_river.messages import FailedSendingMessageException -from replit_river.rpc import ACK_BIT - logger = logging.getLogger(__name__) @@ -35,58 +32,6 @@ class SessionState(enum.Enum): CLOSED = 2 -async def setup_heartbeat( - session_id: str, - heartbeat_ms: float, - heartbeats_until_dead: int, - get_state: Callable[[], SessionState], - get_closing_grace_period: Callable[[], float | None], - close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, - increment_and_get_heartbeat_misses: Callable[[], int], -) -> None: - logger.debug("Start heartbeat") - while True: - await asyncio.sleep(heartbeat_ms / 1000) - state = get_state() - if state != SessionState.ACTIVE: - logger.debug( - "Session is closed, no need to send heartbeat, state : " - "%r close_session_after_this: %r", - {state}, - {get_closing_grace_period()}, - ) - # session is closing / closed, no need to send heartbeat anymore - return - try: - await send_message( - stream_id="heartbeat", - # TODO: make this a message class - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - payload={ - "ack": 0, - }, - control_flags=ACK_BIT, - procedure_name=None, - service_name=None, - span=None, - ) - - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: - if get_closing_grace_period() is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - session_id, - ) - await close_websocket() - continue - except FailedSendingMessageException: - # this is expected during websocket closed period - continue - - async def check_to_close_session( transport_id: str, close_session_check_interval_ms: float, diff --git a/src/replit_river/session.py b/src/replit_river/session.py index d908bdda..68a4f884 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -9,9 +9,9 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from replit_river.common_session import ( + SendMessage, SessionState, check_to_close_session, - setup_heartbeat, ) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( @@ -27,6 +27,7 @@ from replit_river.websocket_wrapper import WebsocketWrapper from .rpc import ( + ACK_BIT, TransportMessage, TransportMessageTracingSetter, ) @@ -306,3 +307,55 @@ async def close(self) -> None: self._streams.clear() self._state = SessionState.CLOSED + + +async def setup_heartbeat( + session_id: str, + heartbeat_ms: float, + heartbeats_until_dead: int, + get_state: Callable[[], SessionState], + get_closing_grace_period: Callable[[], float | None], + close_websocket: Callable[[], Awaitable[None]], + send_message: SendMessage, + increment_and_get_heartbeat_misses: Callable[[], int], +) -> None: + logger.debug("Start heartbeat") + while True: + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state != SessionState.ACTIVE: + logger.debug( + "Session is closed, no need to send heartbeat, state : " + "%r close_session_after_this: %r", + {state}, + {get_closing_grace_period()}, + ) + # session is closing / closed, no need to send heartbeat anymore + return + try: + await send_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "ack": 0, + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) + + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: + # already in grace period, no need to set again + continue + logger.info( + "%r closing websocket because of heartbeat misses", + session_id, + ) + await close_websocket() + continue + except FailedSendingMessageException: + # this is expected during websocket closed period + continue From 84a480bc019402ecdc49b8dbb64cddcfd3bb58b3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:03:03 -0700 Subject: [PATCH 11/29] Representing a richer tapestry of SessionState --- src/replit_river/common_session.py | 19 ++++++++++++++----- src/replit_river/session.py | 9 ++++++--- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 1f3caf3a..45cef88d 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -24,13 +24,20 @@ async def __call__( class SessionState(enum.Enum): """The state a session can be in. - Can only transition from ACTIVE to CLOSING to CLOSED. + Valid transitions: + - NO_CONNECTION -> {ACTIVE} + - ACTIVE -> {NO_CONNECTION, CLOSING} + - CLOSING -> {CLOSED} + - CLOSED -> {} """ - ACTIVE = 0 - CLOSING = 1 - CLOSED = 2 + NO_CONNECTION = 0 + ACTIVE = 1 + CLOSING = 2 + CLOSED = 3 +ConnectingStates = set([SessionState.NO_CONNECTION]) +TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) async def check_to_close_session( transport_id: str, @@ -42,9 +49,11 @@ async def check_to_close_session( ) -> None: while True: await asyncio.sleep(close_session_check_interval_ms / 1000) - if get_state() != SessionState.ACTIVE: + + if get_state() in TerminalStates: # already closing return + # calculate the value now before comparing it so that there are no # await points between the check and the comparison to avoid a TOCTOU # race. diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 68a4f884..3f50f61b 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -12,6 +12,7 @@ SendMessage, SessionState, check_to_close_session, + TerminalStates, ) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( @@ -62,7 +63,7 @@ def __init__( self.session_id = session_id self._transport_options = transport_options - # session state, only modified during closing + # session state self._state = SessionState.ACTIVE self._state_lock = asyncio.Lock() self._close_session_callback = close_session_callback @@ -319,11 +320,13 @@ async def setup_heartbeat( send_message: SendMessage, increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: - logger.debug("Start heartbeat") while True: await asyncio.sleep(heartbeat_ms / 1000) state = get_state() - if state != SessionState.ACTIVE: + if state == SessionState.CONNECTING: + logger.debug("Websocket is not connected, not sending heartbeat") + continue + if state in TerminalStates: logger.debug( "Session is closed, no need to send heartbeat, state : " "%r close_session_after_this: %r", From dabc4b8e4a8de6175b38892e0b6013b88cc139d9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:05:32 -0700 Subject: [PATCH 12/29] We have nanoid types --- src/replit_river/client_session.py | 2 +- src/replit_river/server_transport.py | 2 +- src/replit_river/session.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index e9729e3e..86379185 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -4,7 +4,7 @@ from datetime import timedelta from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never -import nanoid # type: ignore +import nanoid import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 05a36d6d..7b15e36e 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -2,7 +2,7 @@ import logging from typing import Any -import nanoid # type: ignore # type: ignore +import nanoid from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 3f50f61b..bad65693 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -2,7 +2,7 @@ import logging from typing import Any, Awaitable, Callable, Coroutine -import nanoid # type: ignore +import nanoid import websockets from aiochannel import Channel from opentelemetry.trace import Span, use_span From e6dddb4b7930000f394e6e0e310fd4d92c2336ac Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:06:31 -0700 Subject: [PATCH 13/29] Move check_to_close_session over to session --- src/replit_river/common_session.py | 26 -------------------------- src/replit_river/session.py | 27 ++++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 45cef88d..e16a6540 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -39,29 +39,3 @@ class SessionState(enum.Enum): ConnectingStates = set([SessionState.NO_CONNECTION]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) -async def check_to_close_session( - transport_id: str, - close_session_check_interval_ms: float, - get_state: Callable[[], SessionState], - get_current_time: Callable[[], Awaitable[float]], - get_close_session_after_time_secs: Callable[[], float | None], - do_close: Callable[[], Awaitable[None]], -) -> None: - while True: - await asyncio.sleep(close_session_check_interval_ms / 1000) - - if get_state() in TerminalStates: - # already closing - return - - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. - current_time = await get_current_time() - close_session_after_time_secs = get_close_session_after_time_secs() - if not close_session_after_time_secs: - continue - if current_time > close_session_after_time_secs: - logger.info("Grace period ended for %s, closing session", transport_id) - await do_close() - return diff --git a/src/replit_river/session.py b/src/replit_river/session.py index bad65693..4c2eeb38 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -11,7 +11,6 @@ from replit_river.common_session import ( SendMessage, SessionState, - check_to_close_session, TerminalStates, ) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError @@ -310,6 +309,32 @@ async def close(self) -> None: self._state = SessionState.CLOSED +async def check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], Awaitable[None]], +) -> None: + while True: + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() in TerminalStates: + # already closing + return + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. + current_time = await get_current_time() + close_session_after_time_secs = get_close_session_after_time_secs() + if not close_session_after_time_secs: + continue + if current_time > close_session_after_time_secs: + logger.info("Grace period ended for %s, closing session", transport_id) + await do_close() + return + + async def setup_heartbeat( session_id: str, heartbeat_ms: float, From 5c3b91af6d9e3b2719addae24f977acbb91deb7e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:09:28 -0700 Subject: [PATCH 14/29] Turns out none of this was async anyhow --- src/replit_river/client_session.py | 4 +- src/replit_river/common_session.py | 5 +- src/replit_river/seq_manager.py | 73 +++++++++++---------------- src/replit_river/server_session.py | 4 +- src/replit_river/session.py | 18 +++---- src/replit_river/websocket_wrapper.py | 19 +++---- tests/test_seq_manager.py | 27 +++++----- 7 files changed, 64 insertions(+), 86 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 86379185..fc5ef7f0 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -121,7 +121,7 @@ async def _handle_messages_from_ws(self) -> None: ws_wrapper = self._ws_wrapper async for message in ws_wrapper.ws: try: - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): # We should not process messages if the websocket is closed. break msg = parse_transport_msg(message) @@ -132,7 +132,7 @@ async def _handle_messages_from_ws(self) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - match await self._seq_manager.check_seq_and_update(msg): + match self._seq_manager.check_seq_and_update(msg): case IgnoreMessage(): continue case None: diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index e16a6540..2325492e 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -1,7 +1,6 @@ -import asyncio import enum import logging -from typing import Any, Awaitable, Callable, Protocol +from typing import Any, Protocol from opentelemetry.trace import Span @@ -36,6 +35,6 @@ class SessionState(enum.Enum): CLOSING = 2 CLOSED = 3 + ConnectingStates = set([SessionState.NO_CONNECTION]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) - diff --git a/src/replit_river/seq_manager.py b/src/replit_river/seq_manager.py index c75d4d97..8a2f6798 100644 --- a/src/replit_river/seq_manager.py +++ b/src/replit_river/seq_manager.py @@ -1,4 +1,3 @@ -import asyncio import logging from dataclasses import dataclass @@ -40,54 +39,40 @@ class SeqManager: def __init__( self, ) -> None: - self._seq_lock = asyncio.Lock() self.seq = 0 - self._ack_lock = asyncio.Lock() self.ack = 0 self.receiver_ack = 0 - async def get_seq_and_increment(self) -> int: + def get_seq_and_increment(self) -> int: """Get the current sequence number and increment it. This removes one lock acquire than get_seq and increment_seq separately. """ - async with self._seq_lock: - current_value = self.seq - self.seq += 1 - return current_value - - async def increment_seq(self) -> int: - async with self._seq_lock: - self.seq += 1 - return self.seq - - async def get_seq(self) -> int: - async with self._seq_lock: - return self.seq - - async def get_ack(self) -> int: - async with self._ack_lock: - return self.ack - - async def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: - async with self._ack_lock: - if msg.seq != self.ack: - if msg.seq < self.ack: - return IgnoreMessage() - else: - logger.warn( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - self.receiver_ack = msg.ack - await self._set_ack(msg.seq + 1) + current_value = self.seq + self.seq += 1 + return current_value + + def increment_seq(self) -> int: + self.seq += 1 + return self.seq + + def get_seq(self) -> int: + return self.seq + + def get_ack(self) -> int: + return self.ack + + def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: + if msg.seq != self.ack: + if msg.seq < self.ack: + return IgnoreMessage() + else: + logger.warning( + f"Out of order message received got {msg.seq} expected {self.ack}" + ) + + raise OutOfOrderMessageException( + f"Out of order message received got {msg.seq} expected {self.ack}" + ) + self.receiver_ack = msg.ack + self.ack = msg.seq + 1 return None - - async def _set_ack(self, new_ack: int) -> int: - async with self._ack_lock: - self.ack = new_ack - return self.ack diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 085230c1..4af7852f 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -117,7 +117,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: ws_wrapper = self._ws_wrapper async for message in ws_wrapper.ws: try: - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): # We should not process messages if the websocket is closed. break msg = parse_transport_msg(message) @@ -128,7 +128,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - match await self._seq_manager.check_seq_and_update(msg): + match self._seq_manager.check_seq_and_update(msg): case IgnoreMessage(): continue case None: diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 4c2eeb38..00ffcd27 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -121,7 +121,7 @@ async def is_session_open(self) -> bool: async def is_websocket_open(self) -> bool: async with self._ws_lock: - return await self._ws_wrapper.is_open() + return self._ws_wrapper.is_open() async def _begin_close_session_countdown(self) -> None: """Begin the countdown to close session, this should be called when @@ -194,18 +194,18 @@ async def _send_transport_message( async def get_next_expected_seq(self) -> int: """Get the next expected sequence number from the server.""" - return await self._seq_manager.get_ack() + return self._seq_manager.get_ack() async def get_next_sent_seq(self) -> int: """Get the next sequence number that the client will send.""" nextMessage = await self._buffer.peek() if nextMessage: return nextMessage.seq - return await self._seq_manager.get_seq() + return self._seq_manager.get_seq() async def get_next_expected_ack(self) -> int: """Get the next expected ack that the client expects.""" - return await self._seq_manager.get_seq() + return self._seq_manager.get_seq() async def send_message( self, @@ -225,8 +225,8 @@ async def send_message( id=nanoid.generate(), from_=self._transport_id, # type: ignore to=self._to_id, - seq=await self._seq_manager.get_seq_and_increment(), - ack=await self._seq_manager.get_ack(), + seq=self._seq_manager.get_seq_and_increment(), + ack=self._seq_manager.get_ack(), controlFlags=control_flags, payload=payload, serviceName=service_name, @@ -245,7 +245,7 @@ async def send_message( # The session is closed and is no longer accepting new messages. return async with self._ws_lock: - if not await self._ws_wrapper.is_open(): + if not self._ws_wrapper.is_open(): # If the websocket is closed, we should not send the message # and wait for the retry from the buffer. return @@ -271,7 +271,7 @@ async def close_websocket( """Mark the websocket as closed, close the websocket, and retry if needed.""" async with self._ws_lock: # Already closed. - if not await ws_wrapper.is_open(): + if not ws_wrapper.is_open(): return await ws_wrapper.close() if should_retry and self._retry_connection_callback: @@ -348,7 +348,7 @@ async def setup_heartbeat( while True: await asyncio.sleep(heartbeat_ms / 1000) state = get_state() - if state == SessionState.CONNECTING: + if state != SessionState.ACTIVE: logger.debug("Websocket is not connected, not sending heartbeat") continue if state in TerminalStates: diff --git a/src/replit_river/websocket_wrapper.py b/src/replit_river/websocket_wrapper.py index 528fad7e..a1eb6a80 100644 --- a/src/replit_river/websocket_wrapper.py +++ b/src/replit_river/websocket_wrapper.py @@ -18,18 +18,15 @@ class WebsocketWrapper: def __init__(self, ws: WebSocketCommonProtocol) -> None: self.ws = ws self.ws_state = WsState.OPEN - self.ws_lock = asyncio.Lock() self.id = ws.id - async def is_open(self) -> bool: - async with self.ws_lock: - return self.ws_state == WsState.OPEN + def is_open(self) -> bool: + return self.ws_state == WsState.OPEN async def close(self) -> None: - async with self.ws_lock: - if self.ws_state == WsState.OPEN: - self.ws_state = WsState.CLOSING - task = asyncio.create_task(self.ws.close()) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - self.ws_state = WsState.CLOSED + if self.ws_state == WsState.OPEN: + self.ws_state = WsState.CLOSING + task = asyncio.create_task(self.ws.close()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + self.ws_state = WsState.CLOSED diff --git a/tests/test_seq_manager.py b/tests/test_seq_manager.py index 323e0958..fba7f58a 100644 --- a/tests/test_seq_manager.py +++ b/tests/test_seq_manager.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from replit_river.seq_manager import ( @@ -14,17 +12,17 @@ @pytest.mark.asyncio async def test_initial_sequence_and_ack_numbers(no_logging_error: NoErrors) -> None: manager = SeqManager() - assert await manager.get_seq() == 0, "Initial sequence number should be 0" - assert await manager.get_ack() == 0, "Initial acknowledgment number should be 0" + assert manager.get_seq() == 0, "Initial sequence number should be 0" + assert manager.get_ack() == 0, "Initial acknowledgment number should be 0" no_logging_error() @pytest.mark.asyncio async def test_sequence_number_increment(no_logging_error: NoErrors) -> None: manager = SeqManager() - initial_seq = await manager.get_seq_and_increment() + initial_seq = manager.get_seq_and_increment() assert initial_seq == 0, "Sequence number should start at 0" - new_seq = await manager.get_seq() + new_seq = manager.get_seq() assert new_seq == 1, "Sequence number should increment to 1" no_logging_error() @@ -33,41 +31,40 @@ async def test_sequence_number_increment(no_logging_error: NoErrors) -> None: async def test_message_reception(no_logging_error: NoErrors) -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") - await manager.check_seq_and_update( + manager.check_seq_and_update( msg ) # No error should be raised for the correct sequence - assert await manager.get_ack() == 1, "Acknowledgment should be set to 1" + assert manager.get_ack() == 1, "Acknowledgment should be set to 1" # We assert no errors before we send out-of-order messages no_logging_error() # Test duplicate message - assert isinstance(await manager.check_seq_and_update(msg), IgnoreMessage) + assert isinstance(manager.check_seq_and_update(msg), IgnoreMessage) # Test out of order message msg.seq = 2 with pytest.raises(OutOfOrderMessageException): - await manager.check_seq_and_update(msg) + manager.check_seq_and_update(msg) @pytest.mark.asyncio async def test_acknowledgment_setting(no_logging_error: NoErrors) -> None: manager = SeqManager() msg = transport_message(seq=0, ack=0, from_="client") - await manager.check_seq_and_update(msg) - assert await manager.get_ack() == 1, "Acknowledgment number should be updated" + manager.check_seq_and_update(msg) + assert manager.get_ack() == 1, "Acknowledgment number should be updated" no_logging_error() @pytest.mark.asyncio async def test_concurrent_access_to_sequence(no_logging_error: NoErrors) -> None: manager = SeqManager() - tasks = [manager.get_seq_and_increment() for _ in range(10)] - results = await asyncio.gather(*tasks) + results = [manager.get_seq_and_increment() for _ in range(10)] assert len(set(results)) == 10, ( "Each increment call should return a unique sequence number" ) - assert await manager.get_seq() == 10, ( + assert manager.get_seq() == 10, ( "Final sequence number should be 10 after 10 increments" ) no_logging_error() From d6650e7e9316c6919a39ab8ee9ce22908a85c084 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:26:20 -0700 Subject: [PATCH 15/29] Flattening unnecessary wrapper function --- src/replit_river/session.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 00ffcd27..99fa407b 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -158,9 +158,10 @@ async def replace_with_new_websocket( buffered_messages = list(self._buffer.buffer) for msg in buffered_messages: try: - await self._send_transport_message( + await send_transport_message( msg, new_ws, + self._begin_close_session_countdown, ) except WebsocketClosedException: logger.info( @@ -178,20 +179,6 @@ def _reset_session_close_countdown(self) -> None: self._heartbeat_misses = 0 self._close_session_after_time_secs = None - async def _send_transport_message( - self, - msg: TransportMessage, - websocket: websockets.WebSocketCommonProtocol, - ) -> None: - try: - await send_transport_message( - msg, websocket, self._begin_close_session_countdown - ) - except WebsocketClosedException as e: - raise e - except FailedSendingMessageException as e: - raise e - async def get_next_expected_seq(self) -> int: """Get the next expected sequence number from the server.""" return self._seq_manager.get_ack() @@ -249,9 +236,8 @@ async def send_message( # If the websocket is closed, we should not send the message # and wait for the retry from the buffer. return - await self._send_transport_message( - msg, - self._ws_wrapper.ws, + await send_transport_message( + msg, self._ws_wrapper.ws, self._begin_close_session_countdown ) except WebsocketClosedException as e: logger.debug( From b1f41b682388091fccbad6bf87efd47207e457b3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:28:25 -0700 Subject: [PATCH 16/29] Unused --- src/replit_river/message_buffer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 8bcf023c..1365eb81 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -21,11 +21,6 @@ def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self._space_available_cond = asyncio.Condition(lock=self._lock) self._closed = False - async def empty(self) -> bool: - """Check if the buffer is empty""" - async with self._lock: - return len(self.buffer) == 0 - async def put(self, message: TransportMessage) -> None: """Add a message to the buffer. Blocks until there is space in the buffer. From b3a22b0953243af300adb601b8a3164c8d34bad3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:29:19 -0700 Subject: [PATCH 17/29] Splitting has_capacity from put() --- src/replit_river/message_buffer.py | 18 ++++++++++-------- src/replit_river/session.py | 3 ++- tests/test_message_buffer.py | 16 ++++++++++++---- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 1365eb81..3324247d 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -21,19 +21,21 @@ def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self._space_available_cond = asyncio.Condition(lock=self._lock) self._closed = False - async def put(self, message: TransportMessage) -> None: + async def has_capacity(self) -> None: + async with self._space_available_cond: + await self._space_available_cond.wait_for( + lambda: len(self.buffer) < self.max_size or self._closed + ) + + def put(self, message: TransportMessage) -> None: """Add a message to the buffer. Blocks until there is space in the buffer. Raises: MessageBufferClosedError: if the buffer is closed. """ - async with self._space_available_cond: - await self._space_available_cond.wait_for( - lambda: len(self.buffer) < self.max_size or self._closed - ) - if self._closed: - raise MessageBufferClosedError("message buffer is closed") - self.buffer.append(message) + if self._closed: + raise MessageBufferClosedError("message buffer is closed") + self.buffer.append(message) async def peek(self) -> TransportMessage | None: """Peek the first message in the buffer, returns None if the buffer is empty.""" diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 99fa407b..30df6c1e 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -226,8 +226,9 @@ async def send_message( # We need this lock to ensure the buffer order and message sending order # are the same. async with self._msg_lock: + await self._buffer.has_capacity() try: - await self._buffer.put(msg) + self._buffer.put(msg) except MessageBufferClosedError: # The session is closed and is no longer accepting new messages. return diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index 7d330375..02a21ccb 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -35,7 +35,8 @@ async def test_message_buffer_backpressure() -> None: async def put_messages() -> None: for i in range(0, iterations): - await buffer.put(mock_transport_message(seq=i)) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=i)) await sync_events.put(None) background_puts = asyncio.create_task(put_messages()) @@ -55,10 +56,17 @@ async def test_message_buffer_close() -> None: is closed while the put operation is waiting for space in the buffer. """ buffer = MessageBuffer(max_num_messages=1) - await buffer.put(mock_transport_message(seq=1)) - background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1))) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=1)) + + async def bg_put(msg: TransportMessage) -> None: + await buffer.has_capacity() + buffer.put(msg) + + background_put = asyncio.create_task(bg_put(mock_transport_message(seq=1))) await buffer.close() with pytest.raises(MessageBufferClosedError): await background_put with pytest.raises(MessageBufferClosedError): - await buffer.put(mock_transport_message(seq=1)) + await buffer.has_capacity() + buffer.put(mock_transport_message(seq=1)) From d40ec18158015018118d2a258e89b4916cfdd685 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:30:34 -0700 Subject: [PATCH 18/29] Dedicated method for getting the next seq from the buffer --- src/replit_river/message_buffer.py | 5 +++++ src/replit_river/session.py | 5 +---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 3324247d..e21c5d0b 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -37,6 +37,11 @@ def put(self, message: TransportMessage) -> None: raise MessageBufferClosedError("message buffer is closed") self.buffer.append(message) + def get_next_sent_seq(self) -> int | None: + if self.buffer: + return self.buffer[0].seq + return None + async def peek(self) -> TransportMessage | None: """Peek the first message in the buffer, returns None if the buffer is empty.""" async with self._lock: diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 30df6c1e..44fbed6e 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -185,10 +185,7 @@ async def get_next_expected_seq(self) -> int: async def get_next_sent_seq(self) -> int: """Get the next sequence number that the client will send.""" - nextMessage = await self._buffer.peek() - if nextMessage: - return nextMessage.seq - return self._seq_manager.get_seq() + return self._buffer.get_next_sent_seq() or self._seq_manager.get_seq() async def get_next_expected_ack(self) -> int: """Get the next expected ack that the client expects.""" From ac95b8eb87046c8cd1535a560bc76d442bfe0b44 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:32:01 -0700 Subject: [PATCH 19/29] Reflowing message_buffer locks --- src/replit_river/message_buffer.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index e21c5d0b..6e1fdad7 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -17,8 +17,7 @@ class MessageBuffer: def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] - self._lock = asyncio.Lock() - self._space_available_cond = asyncio.Condition(lock=self._lock) + self._space_available_cond = asyncio.Condition() self._closed = False async def has_capacity(self) -> None: @@ -42,23 +41,22 @@ def get_next_sent_seq(self) -> int | None: return self.buffer[0].seq return None - async def peek(self) -> TransportMessage | None: + def peek(self) -> TransportMessage | None: """Peek the first message in the buffer, returns None if the buffer is empty.""" - async with self._lock: - if len(self.buffer) == 0: - return None - return self.buffer[0] + if len(self.buffer) == 0: + return None + return self.buffer[0] async def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" - async with self._lock: - self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + async with self._space_available_cond: self._space_available_cond.notify_all() async def close(self) -> None: """ Closes the message buffer and rejects any pending put operations. """ - async with self._lock: - self._closed = True + self._closed = True + async with self._space_available_cond: self._space_available_cond.notify_all() From a4e375e7ba6660dcec04f156ba76ec5d52f48696 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:36:36 -0700 Subject: [PATCH 20/29] Fix from_ alias type hinting --- src/replit_river/rpc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index f00db903..a0ff5dcf 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -2,6 +2,7 @@ import logging from collections.abc import AsyncIterable, AsyncIterator from typing import ( + Annotated, Any, Awaitable, Callable, @@ -91,7 +92,7 @@ class PropagationContext(BaseModel): class TransportMessage(BaseModel): id: str # from_ is used instead of from because from is a reserved keyword in Python - from_: str = Field(..., alias="from") + from_: Annotated[str, Field(alias="from")] to: str seq: int ack: int From e5262647af958831bc4bce6b30e6f765dfcff3b8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:38:12 -0700 Subject: [PATCH 21/29] Fix test_rpc codegen invocation Module-level imports cause unrecoverable errors on test failure --- tests/codegen/test_rpc.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index c5483432..8ab82095 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -1,5 +1,6 @@ import asyncio import importlib +import os import shutil from datetime import timedelta from pathlib import Path @@ -19,9 +20,8 @@ @pytest.fixture(scope="session", autouse=True) def generate_rpc_client() -> None: - import tests.codegen.rpc.generated - - shutil.rmtree("tests/codegen/rpc/generated") + shutil.rmtree("tests/codegen/rpc/generated", ignore_errors=True) + os.makedirs("tests/codegen/rpc/generated") def file_opener(path: Path) -> TextIO: return open(path, "w") @@ -34,6 +34,12 @@ def file_opener(path: Path) -> TextIO: file_opener=file_opener, method_filter=None, ) + + +@pytest.fixture(scope="session", autouse=True) +def reload_rpc_import(generate_rpc_client: None) -> None: + import tests.codegen.rpc.generated + importlib.reload(tests.codegen.rpc.generated) From a384231821bba1372312c0be1e4db4e89cc81840 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:38:58 -0700 Subject: [PATCH 22/29] Use binary literal instead of hex literal for alignment with PROTOCOL.md --- src/replit_river/rpc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index a0ff5dcf..3459bf7d 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -47,8 +47,8 @@ GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] -ACK_BIT = 0x0001 -STREAM_OPEN_BIT = 0x0002 +ACK_BIT = 0b00001 +STREAM_OPEN_BIT = 0b00010 # these codes are retriable # if the server sends a response with one of these codes, From 85e63cd232fc37fdb3f5fa2fea780fb73fbaac59 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:42:46 -0700 Subject: [PATCH 23/29] Equivalent --- src/replit_river/server_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 4af7852f..b869a64e 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -143,7 +143,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if msg.controlFlags & ACK_BIT != 0: continue async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) + stream = self._streams.get(msg.streamId) if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) From 9880be772afdc3137ef10da9e11571ad9adf5af7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:45:37 -0700 Subject: [PATCH 24/29] Useless lock --- src/replit_river/client_session.py | 6 ++---- src/replit_river/server_session.py | 10 ++++------ src/replit_river/session.py | 4 +--- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index fc5ef7f0..d23f9d7c 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -147,8 +147,7 @@ async def _handle_messages_from_ws(self) -> None: if msg.controlFlags & ACK_BIT != 0: continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) + stream = self._streams.get(msg.streamId, None) if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) @@ -178,8 +177,7 @@ async def _handle_messages_from_ws(self) -> None: if msg.controlFlags & STREAM_CLOSED_BIT != 0: if stream: stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] + del self._streams[msg.streamId] except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index b869a64e..50742b8c 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) + trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() @@ -142,8 +143,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if msg.controlFlags & ACK_BIT != 0: continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId) + stream = self._streams.get(msg.streamId) if msg.controlFlags & STREAM_OPEN_BIT == 0: if not stream: logger.warning("no stream for %s", msg.streamId) @@ -170,15 +170,13 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: if isinstance(_stream, IgnoreMessage): continue if not stream: - async with self._stream_lock: - self._streams[msg.streamId] = _stream + self._streams[msg.streamId] = _stream stream = _stream if msg.controlFlags & STREAM_CLOSED_BIT != 0: if stream: stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] + del self._streams[msg.streamId] except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await ws_wrapper.close() diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 44fbed6e..d8352733 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -75,7 +75,6 @@ def __init__( self._retry_connection_callback = retry_connection_callback # stream for tasks - self._stream_lock = asyncio.Lock() self._streams: dict[str, Channel[Any]] = {} # book keeping @@ -287,8 +286,7 @@ async def close(self) -> None: # throw exception correctly. for stream in self._streams.values(): stream.close() - async with self._stream_lock: - self._streams.clear() + self._streams.clear() self._state = SessionState.CLOSED From 525d9acf69fd2982af72f9595e6e8a62aa31127d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:45:51 -0700 Subject: [PATCH 25/29] Unused --- src/replit_river/server_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 50742b8c..3a931274 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -191,7 +191,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: async def _open_stream_and_call_handler( self, msg: TransportMessage, - tg: asyncio.TaskGroup | None, + tg: asyncio.TaskGroup, ) -> Channel | IgnoreMessage: if not msg.serviceName or not msg.procedureName: logger.warning( From e591d11d09d8fc50dcaa0f9660ec00556615544f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 22:46:10 -0700 Subject: [PATCH 26/29] More field types --- src/replit_river/session.py | 44 +++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index d8352733..f5d36133 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, Awaitable, Callable, Coroutine +from typing import Any, Awaitable, Callable, Coroutine, TypeAlias import nanoid import websockets @@ -37,10 +37,42 @@ trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() +CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +RetryConnectionCallback: TypeAlias = Callable[ + [], + Coroutine[Any, Any, Any], +] + class Session: """Common functionality shared between client_session and server_session""" + _transport_id: str + _to_id: str + session_id: str + _transport_options: TransportOptions + + # session state + _state: SessionState + _state_lock: asyncio.Lock + _close_session_callback: CloseSessionCallback + _close_session_after_time_secs: float | None + + # ws state + _ws_lock: asyncio.Lock + _ws_wrapper: WebsocketWrapper + _heartbeat_misses: int + _retry_connection_callback: RetryConnectionCallback | None + + # stream for tasks + _streams: dict[str, Channel[Any]] + + # book keeping + _seq_manager: SeqManager + _msg_lock: asyncio.Lock + _buffer: MessageBuffer + _task_manager: BackgroundTaskManager + def __init__( self, transport_id: str, @@ -48,14 +80,8 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], - retry_connection_callback: ( - Callable[ - [], - Coroutine[Any, Any, Any], - ] - | None - ) = None, + close_session_callback: CloseSessionCallback, + retry_connection_callback: RetryConnectionCallback | None = None, ) -> None: self._transport_id = transport_id self._to_id = to_id From 8afe6e8a5dd998c3b922ce8d523732391aef29ee Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 28 Mar 2025 08:52:32 -0700 Subject: [PATCH 27/29] Removing last useless Lock() --- src/replit_river/session.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index f5d36133..ac01ffba 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -69,7 +69,6 @@ class Session: # book keeping _seq_manager: SeqManager - _msg_lock: asyncio.Lock _buffer: MessageBuffer _task_manager: BackgroundTaskManager @@ -105,7 +104,6 @@ def __init__( # book keeping self._seq_manager = SeqManager() - self._msg_lock = asyncio.Lock() self._buffer = MessageBuffer(self._transport_options.buffer_size) self._task_manager = BackgroundTaskManager() @@ -229,6 +227,8 @@ async def send_message( # if the session is not active, we should not do anything if self._state != SessionState.ACTIVE: return + await self._buffer.has_capacity() + # Start of critical section. No await between here and buffer.put()! msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), @@ -245,23 +245,19 @@ async def send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) try: - # We need this lock to ensure the buffer order and message sending order - # are the same. - async with self._msg_lock: - await self._buffer.has_capacity() - try: - self._buffer.put(msg) - except MessageBufferClosedError: - # The session is closed and is no longer accepting new messages. + try: + self._buffer.put(msg) + except MessageBufferClosedError: + # The session is closed and is no longer accepting new messages. + return + async with self._ws_lock: + if not self._ws_wrapper.is_open(): + # If the websocket is closed, we should not send the message + # and wait for the retry from the buffer. return - async with self._ws_lock: - if not self._ws_wrapper.is_open(): - # If the websocket is closed, we should not send the message - # and wait for the retry from the buffer. - return - await send_transport_message( - msg, self._ws_wrapper.ws, self._begin_close_session_countdown - ) + await send_transport_message( + msg, self._ws_wrapper.ws, self._begin_close_session_countdown + ) except WebsocketClosedException as e: logger.debug( "Connection closed while sending message %r, waiting for " From 93bb7f62ba817b4162f0c25207410bca17a6c148 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 30 Mar 2025 10:09:56 -0700 Subject: [PATCH 28/29] Restart serve() on ws disconnect --- src/replit_river/client_session.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index d23f9d7c..2d1e847a 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -82,6 +82,13 @@ async def do_close_websocket() -> None: self._setup_heartbeats_task(do_close_websocket) + async def replace_with_new_websocket( + self, new_ws: websockets.WebSocketCommonProtocol + ) -> None: + await super().replace_with_new_websocket(new_ws) + # serve() terminates itself when the ws dies, so we need to start it again + await self.start_serve_responses() + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) From 6d6c9532c32d9908839fe4d436a150b44c7ad0c2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 30 Mar 2025 10:22:23 -0700 Subject: [PATCH 29/29] Only dereference --method-filter when generating clients --- src/replit_river/codegen/run.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 2eb9020b..16d94a5e 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -48,11 +48,6 @@ def main() -> None: client.add_argument("schema", help="schema file") args = parser.parse_args() - method_filter: set[str] | None = None - if args.method_filter: - with open(args.method_filter) as handle: - method_filter = set(x.strip() for x in handle.readlines()) - if args.command == "server": proto_path = os.path.abspath(args.proto) target_directory = os.path.abspath(args.output) @@ -68,6 +63,11 @@ def main() -> None: def file_opener(path: Path) -> TextIO: return open(path, "w") + method_filter: set[str] | None = None + if args.method_filter: + with open(args.method_filter) as handle: + method_filter = set(x.strip() for x in handle.readlines()) + schema_to_river_client_codegen( lambda: open(schema_path), target_path,