Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1ae427c
lint tests as well
blast-hardcheese Mar 28, 2025
76f9d3a
Permit method filtering based on supplied file
blast-hardcheese Mar 26, 2025
bbd2049
Missing invocation of validate_python
blast-hardcheese Mar 28, 2025
64b1b89
inline add_msg_to_stream
blast-hardcheese Mar 28, 2025
90016ee
Moving STREAM_CLOSED_BIT into v1 session objects for clarity in prepa…
blast-hardcheese Mar 28, 2025
7d3b6ba
parse_transport_msg should just return a value
blast-hardcheese Mar 28, 2025
4d86045
Switch from IgnoreMessageException to IgnoreMessage return value
blast-hardcheese Mar 28, 2025
607bacf
Missing ws.close()
blast-hardcheese Mar 28, 2025
c2ecb77
Distribute PROTOCOL_VERSION through the different files it belongs in
blast-hardcheese Mar 28, 2025
0775f21
Moving setup_heartbeat to session.py
blast-hardcheese Mar 28, 2025
84a480b
Representing a richer tapestry of SessionState
blast-hardcheese Mar 28, 2025
dabc4b8
We have nanoid types
blast-hardcheese Mar 28, 2025
e6dddb4
Move check_to_close_session over to session
blast-hardcheese Mar 28, 2025
5c3b91a
Turns out none of this was async anyhow
blast-hardcheese Mar 28, 2025
d6650e7
Flattening unnecessary wrapper function
blast-hardcheese Mar 28, 2025
b1f41b6
Unused
blast-hardcheese Mar 28, 2025
b3a22b0
Splitting has_capacity from put()
blast-hardcheese Mar 28, 2025
d40ec18
Dedicated method for getting the next seq from the buffer
blast-hardcheese Mar 28, 2025
ac95b8e
Reflowing message_buffer locks
blast-hardcheese Mar 28, 2025
a4e375e
Fix from_ alias type hinting
blast-hardcheese Mar 28, 2025
e526264
Fix test_rpc codegen invocation
blast-hardcheese Mar 28, 2025
a384231
Use binary literal instead of hex literal for alignment with PROTOCOL.md
blast-hardcheese Mar 28, 2025
85e63cd
Equivalent
blast-hardcheese Mar 28, 2025
9880be7
Useless lock
blast-hardcheese Mar 28, 2025
525d9ac
Unused
blast-hardcheese Mar 28, 2025
e591d11
More field types
blast-hardcheese Mar 28, 2025
8afe6e8
Removing last useless Lock()
blast-hardcheese Mar 28, 2025
93bb7f6
Restart serve() on ws disconnect
blast-hardcheese Mar 30, 2025
6d6c953
Only dereference --method-filter when generating clients
blast-hardcheese Mar 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions scripts/lint/src/lint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def raise_err(code: int) -> None:

def main() -> None:
fix = ["--fix"] if "--fix" in sys.argv else []
watch = ["--watch"] if "--watch" in sys.argv else []
raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix)))
raise_err(os.system("ruff format src scripts tests"))
raise_err(os.system("mypy src"))
raise_err(os.system("pyright src"))
raise_err(os.system("mypy src tests"))
raise_err(os.system(" ".join(["pyright"] + watch + ["src", "tests"])))
65 changes: 46 additions & 19 deletions src/replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
import logging
from collections.abc import AsyncIterable
from datetime import timedelta
from typing import Any, AsyncGenerator, Callable, Coroutine
from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never

import nanoid # type: ignore
import nanoid
import websockets
from aiochannel import Channel
from aiochannel.errors import ChannelClosed
from opentelemetry.trace import Span
from websockets.exceptions import ConnectionClosed

from replit_river.common_session import add_msg_to_stream
from replit_river.error_schema import (
ERROR_CODE_CANCEL,
ERROR_CODE_STREAM_CLOSED,
Expand All @@ -25,7 +24,7 @@
parse_transport_msg,
)
from replit_river.seq_manager import (
IgnoreMessageException,
IgnoreMessage,
InvalidMessageException,
OutOfOrderMessageException,
)
Expand All @@ -34,7 +33,6 @@

from .rpc import (
ACK_BIT,
STREAM_CLOSED_BIT,
STREAM_OPEN_BIT,
ErrorType,
InitType,
Expand All @@ -45,6 +43,9 @@
logger = logging.getLogger(__name__)


STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2


class ClientSession(Session):
def __init__(
self,
Expand Down Expand Up @@ -81,6 +82,13 @@ async def do_close_websocket() -> None:

self._setup_heartbeats_task(do_close_websocket)

async def replace_with_new_websocket(
self, new_ws: websockets.WebSocketCommonProtocol
) -> None:
await super().replace_with_new_websocket(new_ws)
# serve() terminates itself when the ws dies, so we need to start it again
await self.start_serve_responses()

async def start_serve_responses(self) -> None:
self._task_manager.create_task(self.serve())

Expand Down Expand Up @@ -120,31 +128,54 @@ async def _handle_messages_from_ws(self) -> None:
ws_wrapper = self._ws_wrapper
async for message in ws_wrapper.ws:
try:
if not await ws_wrapper.is_open():
if not ws_wrapper.is_open():
# We should not process messages if the websocket is closed.
break
msg = parse_transport_msg(message, self._transport_options)
msg = parse_transport_msg(message)
if isinstance(msg, str):
logger.debug("Ignoring transport message", exc_info=True)
continue

logger.debug(f"{self._transport_id} got a message %r", msg)

# Update bookkeeping
await self._seq_manager.check_seq_and_update(msg)
match self._seq_manager.check_seq_and_update(msg):
case IgnoreMessage():
continue
case None:
pass
case other:
assert_never(other)

await self._buffer.remove_old_messages(
self._seq_manager.receiver_ack,
)
self._reset_session_close_countdown()

if msg.controlFlags & ACK_BIT != 0:
continue
async with self._stream_lock:
stream = self._streams.get(msg.streamId, None)
stream = self._streams.get(msg.streamId, None)
if msg.controlFlags & STREAM_OPEN_BIT == 0:
if not stream:
logger.warning("no stream for %s", msg.streamId)
raise IgnoreMessageException(
"no stream for message, ignoring"
)
await add_msg_to_stream(msg, stream)
continue

if (
msg.controlFlags & STREAM_CLOSED_BIT != 0
and msg.payload.get("type", None) == "CLOSE"
):
# close message is not sent to the stream
pass
else:
try:
await stream.put(msg.payload)
except ChannelClosed:
# The client is no longer interested in this stream,
# just drop the message.
pass
except RuntimeError as e:
raise InvalidMessageException(e) from e

else:
raise InvalidMessageException(
"Client should not receive stream open bit"
Expand All @@ -153,11 +184,7 @@ async def _handle_messages_from_ws(self) -> None:
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
if stream:
stream.close()
async with self._stream_lock:
del self._streams[msg.streamId]
except IgnoreMessageException:
logger.debug("Ignoring transport message", exc_info=True)
continue
del self._streams[msg.streamId]
except OutOfOrderMessageException:
logger.exception("Out of order message, closing connection")
await ws_wrapper.close()
Expand Down
15 changes: 9 additions & 6 deletions src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
RiverException,
)
from replit_river.messages import (
PROTOCOL_VERSION,
FailedSendingMessageException,
WebsocketClosedException,
parse_transport_msg,
Expand All @@ -34,7 +33,6 @@
TransportMessage,
)
from replit_river.seq_manager import (
IgnoreMessageException,
InvalidMessageException,
)
from replit_river.session import Session
Expand All @@ -44,6 +42,8 @@
UriAndMetadata,
)

PROTOCOL_VERSION = "v1.1"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -116,6 +116,8 @@ async def get_or_create_session(self) -> ClientSession:
return existing_session
else:
logger.info("Closing stale session %s", existing_session.session_id)
await new_ws.close() # NB(dstewart): This wasn't there in the
# v1 transport, were we just leaking WS?
await existing_session.close()
return await self._create_new_session()

Expand Down Expand Up @@ -293,10 +295,11 @@ async def _get_handshake_response_msg(
"Handshake failed, conn closed while waiting for response",
) from e
try:
return parse_transport_msg(data, self._transport_options)
except IgnoreMessageException:
logger.debug("Ignoring transport message", exc_info=True)
continue
msg = parse_transport_msg(data)
if isinstance(msg, str):
logger.debug("Ignoring transport message", exc_info=True)
continue
return msg
except InvalidMessageException as e:
raise RiverException(
ERROR_HANDSHAKE,
Expand Down
23 changes: 18 additions & 5 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ def generate_individual_service(
schema_name: str,
schema: RiverService,
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
method_filter: set[str] | None,
) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []

Expand Down Expand Up @@ -837,6 +838,8 @@ def __init__(self, client: river.Client[Any]):
),
]
for name, procedure in schema.procedures.items():
if method_filter and (schema_name + "." + name) in method_filter:
continue
module_names = [ModuleName(name)]
init_type: TypeExpression | None = None
if procedure.init:
Expand Down Expand Up @@ -980,7 +983,7 @@ def __init__(self, client: river.Client[Any]):
)
render_init_method = f"""\
lambda x: {render_type_expr(init_type_type_adapter_name)}
.validate_python
.validate_python(x)
"""

assert init_type is None or render_init_method, (
Expand Down Expand Up @@ -1223,6 +1226,7 @@ def generate_river_client_module(
client_name: str,
schema_root: RiverSchema,
typed_dict_inputs: bool,
method_filter: set[str] | None,
) -> dict[RenderedPath, FileContents]:
files: dict[RenderedPath, FileContents] = {}

Expand All @@ -1247,10 +1251,15 @@ def generate_river_client_module(
)
for schema_name, schema in schema_root.services.items():
module_name, class_name, emitted_files = generate_individual_service(
schema_name, schema, input_base_class
schema_name,
schema,
input_base_class,
method_filter,
)
files.update(emitted_files)
modules.append((module_name, class_name))
if emitted_files:
# Short-cut if we didn't actually emit anything
files.update(emitted_files)
modules.append((module_name, class_name))

main_contents = generate_common_client(
client_name, handshake_type, handshake_chunks, modules
Expand All @@ -1266,12 +1275,16 @@ def schema_to_river_client_codegen(
client_name: str,
typed_dict_inputs: bool,
file_opener: Callable[[Path], TextIO],
method_filter: set[str] | None,
) -> None:
"""Generates the lines of a River module."""
with read_schema() as f:
schemas = RiverSchemaFile(json.load(f))
for subpath, contents in generate_river_client_module(
client_name, schemas.root, typed_dict_inputs
client_name,
schemas.root,
typed_dict_inputs,
method_filter,
).items():
module_path = Path(target_path).joinpath(subpath)
module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
Expand Down
13 changes: 13 additions & 0 deletions src/replit_river/codegen/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os.path
import pathlib
from pathlib import Path
from typing import TextIO

Expand Down Expand Up @@ -38,6 +39,12 @@ def main() -> None:
action="store_true",
default=False,
)
client.add_argument(
"--method-filter",
help="Only generate a subset of the specified methods",
action="store",
type=pathlib.Path,
)
client.add_argument("schema", help="schema file")
args = parser.parse_args()

Expand All @@ -56,12 +63,18 @@ def main() -> None:
def file_opener(path: Path) -> TextIO:
return open(path, "w")

method_filter: set[str] | None = None
if args.method_filter:
with open(args.method_filter) as handle:
method_filter = set(x.strip() for x in handle.readlines())

schema_to_river_client_codegen(
lambda: open(schema_path),
target_path,
args.client_name,
args.typed_dict_inputs,
file_opener,
method_filter=method_filter,
)
else:
raise NotImplementedError(f"Unknown command {args.command}")
Loading
Loading