Skip to content

Commit fa382a7

Browse files
chore/v1 improvements (#149)
Why === Split out of #106 for readability What changed ============ - Code organization - Tooling tweaks - Removing useless `Lock`s Test plan ========= CI river-babel
1 parent 38554fc commit fa382a7

18 files changed

+414
-361
lines changed

scripts/lint/src/lint/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ def raise_err(code: int) -> None:
1111

1212
def main() -> None:
1313
fix = ["--fix"] if "--fix" in sys.argv else []
14+
watch = ["--watch"] if "--watch" in sys.argv else []
1415
raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix)))
1516
raise_err(os.system("ruff format src scripts tests"))
16-
raise_err(os.system("mypy src"))
17-
raise_err(os.system("pyright src"))
17+
raise_err(os.system("mypy src tests"))
18+
raise_err(os.system(" ".join(["pyright"] + watch + ["src", "tests"])))

src/replit_river/client_session.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Coroutine
5+
from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never
66

7-
import nanoid # type: ignore
7+
import nanoid
88
import websockets
99
from aiochannel import Channel
1010
from aiochannel.errors import ChannelClosed
1111
from opentelemetry.trace import Span
1212
from websockets.exceptions import ConnectionClosed
1313

14-
from replit_river.common_session import add_msg_to_stream
1514
from replit_river.error_schema import (
1615
ERROR_CODE_CANCEL,
1716
ERROR_CODE_STREAM_CLOSED,
@@ -25,7 +24,7 @@
2524
parse_transport_msg,
2625
)
2726
from replit_river.seq_manager import (
28-
IgnoreMessageException,
27+
IgnoreMessage,
2928
InvalidMessageException,
3029
OutOfOrderMessageException,
3130
)
@@ -34,7 +33,6 @@
3433

3534
from .rpc import (
3635
ACK_BIT,
37-
STREAM_CLOSED_BIT,
3836
STREAM_OPEN_BIT,
3937
ErrorType,
4038
InitType,
@@ -45,6 +43,9 @@
4543
logger = logging.getLogger(__name__)
4644

4745

46+
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2
47+
48+
4849
class ClientSession(Session):
4950
def __init__(
5051
self,
@@ -81,6 +82,13 @@ async def do_close_websocket() -> None:
8182

8283
self._setup_heartbeats_task(do_close_websocket)
8384

85+
async def replace_with_new_websocket(
86+
self, new_ws: websockets.WebSocketCommonProtocol
87+
) -> None:
88+
await super().replace_with_new_websocket(new_ws)
89+
# serve() terminates itself when the ws dies, so we need to start it again
90+
await self.start_serve_responses()
91+
8492
async def start_serve_responses(self) -> None:
8593
self._task_manager.create_task(self.serve())
8694

@@ -120,31 +128,54 @@ async def _handle_messages_from_ws(self) -> None:
120128
ws_wrapper = self._ws_wrapper
121129
async for message in ws_wrapper.ws:
122130
try:
123-
if not await ws_wrapper.is_open():
131+
if not ws_wrapper.is_open():
124132
# We should not process messages if the websocket is closed.
125133
break
126-
msg = parse_transport_msg(message, self._transport_options)
134+
msg = parse_transport_msg(message)
135+
if isinstance(msg, str):
136+
logger.debug("Ignoring transport message", exc_info=True)
137+
continue
127138

128139
logger.debug(f"{self._transport_id} got a message %r", msg)
129140

130141
# Update bookkeeping
131-
await self._seq_manager.check_seq_and_update(msg)
142+
match self._seq_manager.check_seq_and_update(msg):
143+
case IgnoreMessage():
144+
continue
145+
case None:
146+
pass
147+
case other:
148+
assert_never(other)
149+
132150
await self._buffer.remove_old_messages(
133151
self._seq_manager.receiver_ack,
134152
)
135153
self._reset_session_close_countdown()
136154

137155
if msg.controlFlags & ACK_BIT != 0:
138156
continue
139-
async with self._stream_lock:
140-
stream = self._streams.get(msg.streamId, None)
157+
stream = self._streams.get(msg.streamId, None)
141158
if msg.controlFlags & STREAM_OPEN_BIT == 0:
142159
if not stream:
143160
logger.warning("no stream for %s", msg.streamId)
144-
raise IgnoreMessageException(
145-
"no stream for message, ignoring"
146-
)
147-
await add_msg_to_stream(msg, stream)
161+
continue
162+
163+
if (
164+
msg.controlFlags & STREAM_CLOSED_BIT != 0
165+
and msg.payload.get("type", None) == "CLOSE"
166+
):
167+
# close message is not sent to the stream
168+
pass
169+
else:
170+
try:
171+
await stream.put(msg.payload)
172+
except ChannelClosed:
173+
# The client is no longer interested in this stream,
174+
# just drop the message.
175+
pass
176+
except RuntimeError as e:
177+
raise InvalidMessageException(e) from e
178+
148179
else:
149180
raise InvalidMessageException(
150181
"Client should not receive stream open bit"
@@ -153,11 +184,7 @@ async def _handle_messages_from_ws(self) -> None:
153184
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
154185
if stream:
155186
stream.close()
156-
async with self._stream_lock:
157-
del self._streams[msg.streamId]
158-
except IgnoreMessageException:
159-
logger.debug("Ignoring transport message", exc_info=True)
160-
continue
187+
del self._streams[msg.streamId]
161188
except OutOfOrderMessageException:
162189
logger.exception("Out of order message, closing connection")
163190
await ws_wrapper.close()

src/replit_river/client_transport.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
RiverException,
2020
)
2121
from replit_river.messages import (
22-
PROTOCOL_VERSION,
2322
FailedSendingMessageException,
2423
WebsocketClosedException,
2524
parse_transport_msg,
@@ -34,7 +33,6 @@
3433
TransportMessage,
3534
)
3635
from replit_river.seq_manager import (
37-
IgnoreMessageException,
3836
InvalidMessageException,
3937
)
4038
from replit_river.session import Session
@@ -44,6 +42,8 @@
4442
UriAndMetadata,
4543
)
4644

45+
PROTOCOL_VERSION = "v1.1"
46+
4747
logger = logging.getLogger(__name__)
4848

4949

@@ -116,6 +116,8 @@ async def get_or_create_session(self) -> ClientSession:
116116
return existing_session
117117
else:
118118
logger.info("Closing stale session %s", existing_session.session_id)
119+
await new_ws.close() # NB(dstewart): This wasn't there in the
120+
# v1 transport, were we just leaking WS?
119121
await existing_session.close()
120122
return await self._create_new_session()
121123

@@ -293,10 +295,11 @@ async def _get_handshake_response_msg(
293295
"Handshake failed, conn closed while waiting for response",
294296
) from e
295297
try:
296-
return parse_transport_msg(data, self._transport_options)
297-
except IgnoreMessageException:
298-
logger.debug("Ignoring transport message", exc_info=True)
299-
continue
298+
msg = parse_transport_msg(data)
299+
if isinstance(msg, str):
300+
logger.debug("Ignoring transport message", exc_info=True)
301+
continue
302+
return msg
300303
except InvalidMessageException as e:
301304
raise RiverException(
302305
ERROR_HANDSHAKE,

src/replit_river/codegen/client.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,7 @@ def generate_individual_service(
802802
schema_name: str,
803803
schema: RiverService,
804804
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
805+
method_filter: set[str] | None,
805806
) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
806807
serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
807808

@@ -837,6 +838,8 @@ def __init__(self, client: river.Client[Any]):
837838
),
838839
]
839840
for name, procedure in schema.procedures.items():
841+
if method_filter and (schema_name + "." + name) in method_filter:
842+
continue
840843
module_names = [ModuleName(name)]
841844
init_type: TypeExpression | None = None
842845
if procedure.init:
@@ -980,7 +983,7 @@ def __init__(self, client: river.Client[Any]):
980983
)
981984
render_init_method = f"""\
982985
lambda x: {render_type_expr(init_type_type_adapter_name)}
983-
.validate_python
986+
.validate_python(x)
984987
"""
985988

986989
assert init_type is None or render_init_method, (
@@ -1223,6 +1226,7 @@ def generate_river_client_module(
12231226
client_name: str,
12241227
schema_root: RiverSchema,
12251228
typed_dict_inputs: bool,
1229+
method_filter: set[str] | None,
12261230
) -> dict[RenderedPath, FileContents]:
12271231
files: dict[RenderedPath, FileContents] = {}
12281232

@@ -1247,10 +1251,15 @@ def generate_river_client_module(
12471251
)
12481252
for schema_name, schema in schema_root.services.items():
12491253
module_name, class_name, emitted_files = generate_individual_service(
1250-
schema_name, schema, input_base_class
1254+
schema_name,
1255+
schema,
1256+
input_base_class,
1257+
method_filter,
12511258
)
1252-
files.update(emitted_files)
1253-
modules.append((module_name, class_name))
1259+
if emitted_files:
1260+
# Short-cut if we didn't actually emit anything
1261+
files.update(emitted_files)
1262+
modules.append((module_name, class_name))
12541263

12551264
main_contents = generate_common_client(
12561265
client_name, handshake_type, handshake_chunks, modules
@@ -1266,12 +1275,16 @@ def schema_to_river_client_codegen(
12661275
client_name: str,
12671276
typed_dict_inputs: bool,
12681277
file_opener: Callable[[Path], TextIO],
1278+
method_filter: set[str] | None,
12691279
) -> None:
12701280
"""Generates the lines of a River module."""
12711281
with read_schema() as f:
12721282
schemas = RiverSchemaFile(json.load(f))
12731283
for subpath, contents in generate_river_client_module(
1274-
client_name, schemas.root, typed_dict_inputs
1284+
client_name,
1285+
schemas.root,
1286+
typed_dict_inputs,
1287+
method_filter,
12751288
).items():
12761289
module_path = Path(target_path).joinpath(subpath)
12771290
module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)

src/replit_river/codegen/run.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import os.path
3+
import pathlib
34
from pathlib import Path
45
from typing import TextIO
56

@@ -38,6 +39,12 @@ def main() -> None:
3839
action="store_true",
3940
default=False,
4041
)
42+
client.add_argument(
43+
"--method-filter",
44+
help="Only generate a subset of the specified methods",
45+
action="store",
46+
type=pathlib.Path,
47+
)
4148
client.add_argument("schema", help="schema file")
4249
args = parser.parse_args()
4350

@@ -56,12 +63,18 @@ def main() -> None:
5663
def file_opener(path: Path) -> TextIO:
5764
return open(path, "w")
5865

66+
method_filter: set[str] | None = None
67+
if args.method_filter:
68+
with open(args.method_filter) as handle:
69+
method_filter = set(x.strip() for x in handle.readlines())
70+
5971
schema_to_river_client_codegen(
6072
lambda: open(schema_path),
6173
target_path,
6274
args.client_name,
6375
args.typed_dict_inputs,
6476
file_opener,
77+
method_filter=method_filter,
6578
)
6679
else:
6780
raise NotImplementedError(f"Unknown command {args.command}")

0 commit comments

Comments
 (0)