Skip to content

Commit 0046cac

Browse files
Permit method filtering based on supplied file
1 parent 6b17c42 commit 0046cac

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

src/replit_river/codegen/client.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,7 @@ def generate_individual_service(
11131113
schema: RiverService,
11141114
input_base_class: Literal["TypedDict"] | Literal["BaseModel"],
11151115
protocol_version: Literal["v1.1", "v2.0"],
1116+
method_filter: set[str] | None,
11161117
) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]:
11171118
serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = []
11181119

@@ -1155,6 +1156,8 @@ def __init__(self, client: {client_module}.Client[Any]):
11551156
),
11561157
]
11571158
for name, procedure in schema.procedures.items():
1159+
if method_filter and (schema_name + "." + name) in method_filter:
1160+
continue
11581161
module_names = [ModuleName(name)]
11591162
init_type: TypeExpression | None = None
11601163
init_module_info: list[ModuleName] = []
@@ -1408,6 +1411,7 @@ def generate_river_client_module(
14081411
schema_root: RiverSchema,
14091412
typed_dict_inputs: bool,
14101413
protocol_version: Literal["v1.1", "v2.0"],
1414+
method_filter: set[str] | None,
14111415
) -> dict[RenderedPath, FileContents]:
14121416
files: dict[RenderedPath, FileContents] = {}
14131417

@@ -1436,9 +1440,12 @@ def generate_river_client_module(
14361440
schema,
14371441
input_base_class,
14381442
protocol_version,
1443+
method_filter,
14391444
)
1440-
files.update(emitted_files)
1441-
modules.append((module_name, class_name))
1445+
if emitted_files:
1446+
# Short-cut if we didn't actually emit anything
1447+
files.update(emitted_files)
1448+
modules.append((module_name, class_name))
14421449

14431450
main_contents = generate_common_client(
14441451
client_name,
@@ -1459,6 +1466,7 @@ def schema_to_river_client_codegen(
14591466
typed_dict_inputs: bool,
14601467
file_opener: Callable[[Path], TextIO],
14611468
protocol_version: Literal["v1.1", "v2.0"],
1469+
method_filter: set[str] | None,
14621470
) -> None:
14631471
"""Generates the lines of a River module."""
14641472
with read_schema() as f:
@@ -1468,6 +1476,7 @@ def schema_to_river_client_codegen(
14681476
schemas.root,
14691477
typed_dict_inputs,
14701478
protocol_version,
1479+
method_filter,
14711480
).items():
14721481
module_path = Path(target_path).joinpath(subpath)
14731482
module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)

src/replit_river/codegen/run.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import os.path
3+
import pathlib
34
from pathlib import Path
45
from typing import TextIO
56

@@ -45,9 +46,20 @@ def main() -> None:
4546
default="v1.1",
4647
choices=["v1.1", "v2.0"],
4748
)
49+
client.add_argument(
50+
"--method-filter",
51+
help="Only generate a subset of the specified methods",
52+
action="store",
53+
type=pathlib.Path,
54+
)
4855
client.add_argument("schema", help="schema file")
4956
args = parser.parse_args()
5057

58+
method_filter: set[str] | None = None
59+
if args.method_filter:
60+
with open(args.method_filter) as handle:
61+
method_filter = set(x.strip() for x in handle.readlines())
62+
5163
if args.command == "server":
5264
proto_path = os.path.abspath(args.proto)
5365
target_directory = os.path.abspath(args.output)
@@ -70,6 +82,7 @@ def file_opener(path: Path) -> TextIO:
7082
typed_dict_inputs=args.typed_dict_inputs,
7183
file_opener=file_opener,
7284
protocol_version=args.protocol_version,
85+
method_filter=method_filter,
7386
)
7487
else:
7588
raise NotImplementedError(f"Unknown command {args.command}")

tests/codegen/snapshot/codegen_snapshot_fixtures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def file_opener(path: Path) -> TextIO:
3535
file_opener=file_opener,
3636
typed_dict_inputs=True,
3737
protocol_version="v1.1",
38+
method_filter=None,
3839
)
3940
for path, file in files.items():
4041
file.seek(0)

tests/codegen/test_rpc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def file_opener(path: Path) -> TextIO:
3333
typed_dict_inputs=True,
3434
file_opener=file_opener,
3535
protocol_version="v1.1",
36+
method_filter=None,
3637
)
3738
importlib.reload(tests.codegen.rpc.generated)
3839

0 commit comments

Comments
 (0)