From 0647cf7d80dc81940862d22e122309f9f497e27d Mon Sep 17 00:00:00 2001 From: PotentialStyx <62217716+PotentialStyx@users.noreply.github.com> Date: Mon, 4 Aug 2025 14:18:11 -0700 Subject: [PATCH 1/2] Fix codegen for fields with dashes --- src/replit_river/codegen/client.py | 28 ++++++++++++++++++++++----- src/replit_river/codegen/typing.py | 31 +++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 815d696c..f5361526 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -24,6 +24,7 @@ FileContents, HandshakeType, ListTypeExpr, + LiteralType, LiteralTypeExpr, ModuleName, NoneTypeExpr, @@ -33,6 +34,7 @@ TypeName, UnionTypeExpr, extract_inner_type, + normalize_special_chars, render_literal_type, render_type_expr, ) @@ -396,9 +398,12 @@ def {_field_name}( case NoneTypeExpr(): typeddict_encoder.append("None") case other: - _o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = ( - other - ) + _o2: ( + DictTypeExpr + | OpenUnionTypeExpr + | UnionTypeExpr + | LiteralType + ) = other raise ValueError(f"What does it mean to have {_o2} here?") if permit_unknown_members: union = _make_open_union_type_expr(any_of) @@ -491,7 +496,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: return (NoneTypeExpr(), [], [], set()) elif type.type == "Date": typeddict_encoder.append("TODO: dstewart") - return (TypeName("datetime.datetime"), [], [], set()) + return (LiteralType("datetime.datetime"), [], [], set()) elif type.type == "array" and type.items: type_name, module_info, type_chunks, encoder_names = encode_type( type.items, @@ -692,8 +697,21 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: ) ) else: + specialized_name = normalize_special_chars(name) + effective_name = name + extras = "" + if name != specialized_name: + if base_model != "BaseModel": + # TODO: alias support for TypedDict + raise ValueError( + f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501 + ) + # Pydantic doesn't allow leading underscores in field names + effective_name = specialized_name.lstrip("_") + extras = f" = Field(serialization_alias={repr(name)})" + current_chunks.append( - f" {name}: {render_type_expr(type_name)}" + f" {effective_name}: {render_type_expr(type_name)}{extras}" ) typeddict_encoder.append(",") typeddict_encoder.append("}") diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 68443ffa..626b2180 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import NewType, assert_never, cast +SPECIAL_CHARS = [".", "-", ":", "/", "@", " ", "$", "!", "?", "=", "&", "|", "~", "`"] + ModuleName = NewType("ModuleName", str) ClassName = NewType("ClassName", str) FileContents = NewType("FileContents", str) @@ -23,6 +25,20 @@ def __lt__(self, other: object) -> bool: return hash(self) < hash(other) +@dataclass(frozen=True) +class LiteralType: + value: str + + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") + + def __eq__(self, other: object) -> bool: + return isinstance(other, LiteralType) and other.value == self.value + + def __lt__(self, other: object) -> bool: + return hash(self) < hash(other) + + @dataclass(frozen=True) class NoneTypeExpr: def __str__(self) -> str: @@ -111,6 +127,7 @@ def __lt__(self, other: object) -> bool: TypeExpression = ( TypeName + | LiteralType | NoneTypeExpr | DictTypeExpr | ListTypeExpr @@ -145,6 +162,12 @@ def work( raise ValueError("Incoherent state when trying to flatten unions") +def normalize_special_chars(value: str) -> str: + for char in SPECIAL_CHARS: + value = value.replace(char, "_") + return value + + def render_type_expr(value: TypeExpression) -> str: match _flatten_nested_unions(value): case DictTypeExpr(nested): @@ -192,7 +215,9 @@ def render_type_expr(value: TypeExpression) -> str: "]" ) case TypeName(name): - return name + return normalize_special_chars(name) + case LiteralType(literal_value): + return literal_value case NoneTypeExpr(): return "None" case other: @@ -223,6 +248,10 @@ def extract_inner_type(value: TypeExpression) -> TypeName: ) case TypeName(name): return TypeName(name) + case LiteralType(name): + raise ValueError( + f"Attempting to extract from a literal type: {repr(value)}" + ) case NoneTypeExpr(): raise ValueError( f"Attempting to extract from a literal 'None': {repr(value)}", From 437fece07a425e4c540ead14822fd33be2b0e8c3 Mon Sep 17 00:00:00 2001 From: PotentialStyx <62217716+PotentialStyx@users.noreply.github.com> Date: Tue, 5 Aug 2025 10:55:50 -0700 Subject: [PATCH 2/2] Fix optional and add error on field conflict --- src/replit_river/codegen/client.py | 58 ++++++++++++++----- tests/conftest.py | 7 ++- .../rpc/generated/test_service/rpc_method.py | 8 +++ tests/v1/codegen/rpc/invalid-schema.json | 30 ++++++++++ tests/v1/codegen/rpc/schema.json | 19 +++++- .../enumService/needsEnumObject.py | 8 +-- tests/v1/codegen/test_invalid_schema.py | 27 +++++++++ tests/v1/codegen/test_rpc.py | 7 +-- 8 files changed, 137 insertions(+), 27 deletions(-) create mode 100644 tests/v1/codegen/rpc/invalid-schema.json create mode 100644 tests/v1/codegen/test_invalid_schema.py diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index f5361526..d89359d3 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -1,6 +1,7 @@ import json import re import subprocess +from collections import defaultdict from pathlib import Path from textwrap import dedent from typing import ( @@ -529,6 +530,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: # lambda x: ... vs lambda _: {} needs_binding = False encoder_names = set() + # Track effective field names to detect collisions after normalization + # Maps effective name -> list of original field names + effective_field_names: defaultdict[str, list[str]] = defaultdict(list) if type.properties: needs_binding = True typeddict_encoder.append("{") @@ -658,19 +662,37 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: value = "" if base_model != "TypedDict": value = f"= {field_value}" + # Track $kind -> "kind" mapping for collision detection + effective_field_names["kind"].append(name) + current_chunks.append( f" kind: Annotated[{render_type_expr(type_name)}, Field(alias={ repr(name) })]{value}" ) else: + specialized_name = normalize_special_chars(name) + effective_name = name + extras = [] + if name != specialized_name: + if base_model != "BaseModel": + # TODO: alias support for TypedDict + raise ValueError( + f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501 + ) + # Pydantic doesn't allow leading underscores in field names + effective_name = specialized_name.lstrip("_") + extras.append(f"alias={repr(name)}") + + effective_field_names[effective_name].append(name) + if name not in type.required: if base_model == "TypedDict": current_chunks.append( reindent( " ", f"""\ - {name}: NotRequired[{ + {effective_name}: NotRequired[{ render_type_expr( UnionTypeExpr([type_name, NoneTypeExpr()]) ) @@ -679,11 +701,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: ) ) else: + extras.append("default=None") + current_chunks.append( reindent( " ", f"""\ - {name}: { + {effective_name}: { render_type_expr( UnionTypeExpr( [ @@ -692,28 +716,30 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: ] ) ) - } = None + } = Field({", ".join(extras)}) """, ) ) else: - specialized_name = normalize_special_chars(name) - effective_name = name - extras = "" - if name != specialized_name: - if base_model != "BaseModel": - # TODO: alias support for TypedDict - raise ValueError( - f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501 - ) - # Pydantic doesn't allow leading underscores in field names - effective_name = specialized_name.lstrip("_") - extras = f" = Field(serialization_alias={repr(name)})" + extras_str = "" + if len(extras) != 0: + extras_str = f" = Field({', '.join(extras)})" current_chunks.append( - f" {effective_name}: {render_type_expr(type_name)}{extras}" + f" {effective_name}: {render_type_expr(type_name)}{extras_str}" # noqa: E501 ) typeddict_encoder.append(",") + + # Check for field name collisions after processing all fields + for effective_name, original_names in effective_field_names.items(): + if len(original_names) > 1: + error_msg = ( + f"Field name collision: fields {original_names} all normalize " + f"to the same effective name '{effective_name}'" + ) + + raise ValueError(error_msg) + typeddict_encoder.append("}") # exclude_none typeddict_encoder = ( diff --git a/tests/conftest.py b/tests/conftest.py index db928bc5..7615f9b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from typing import Any, Literal, Mapping import nanoid @@ -55,7 +56,11 @@ def deserialize_request(request: dict) -> str: def serialize_response(response: str) -> dict: - return {"data": response} + return { + "data": response, + "data2": datetime.now(timezone.utc), + "data-3": {"data-test": "test"}, + } def deserialize_response(response: dict) -> str: diff --git a/tests/v1/codegen/rpc/generated/test_service/rpc_method.py b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py index 1e40411f..d839f4af 100644 --- a/tests/v1/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py @@ -30,6 +30,7 @@ def encode_Rpc_MethodInput( for (k, v) in ( { "data": x.get("data"), + "data2": x.get("data2"), } ).items() if v is not None @@ -38,10 +39,17 @@ def encode_Rpc_MethodInput( class Rpc_MethodInput(TypedDict): data: str + data2: datetime.datetime + + +class Rpc_MethodOutputData_3(BaseModel): + data_test: str | None = Field(alias="data-test", default=None) class Rpc_MethodOutput(BaseModel): data: str + data_3: Rpc_MethodOutputData_3 = Field(alias="data-3") + data2: datetime.datetime Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( diff --git a/tests/v1/codegen/rpc/invalid-schema.json b/tests/v1/codegen/rpc/invalid-schema.json new file mode 100644 index 00000000..2652f37a --- /dev/null +++ b/tests/v1/codegen/rpc/invalid-schema.json @@ -0,0 +1,30 @@ +{ + "services": { + "test_service": { + "procedures": { + "rpc_method": { + "input": { + "type": "boolean" + }, + "output": { + "type": "object", + "properties": { + "data:3": { + "type": "Date" + }, + "data-3": { + "type": "boolean" + } + }, + "required": ["data:3"] + }, + "errors": { + "not": {} + }, + "type": "rpc" + } + } + } + } + } + \ No newline at end of file diff --git a/tests/v1/codegen/rpc/schema.json b/tests/v1/codegen/rpc/schema.json index 508d354e..1c78df3f 100644 --- a/tests/v1/codegen/rpc/schema.json +++ b/tests/v1/codegen/rpc/schema.json @@ -8,18 +8,33 @@ "properties": { "data": { "type": "string" + }, + "data2": { + "type": "Date" } }, - "required": ["data"] + "required": ["data", "data2"] }, "output": { "type": "object", "properties": { "data": { "type": "string" + }, + "data2": { + "type": "Date" + }, + "data-3": { + "type": "object", + "properties": { + "data-test": { + "type": "string" + } + }, + "required": [] } }, - "required": ["data"] + "required": ["data", "data2", "data-3"] }, "errors": { "not": {} diff --git a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 2817a039..e875810a 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -96,7 +96,7 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): class NeedsenumobjectOutput(BaseModel): - foo: NeedsenumobjectOutputFoo | None = None + foo: NeedsenumobjectOutputFoo | None = Field(default=None) NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( @@ -105,11 +105,11 @@ class NeedsenumobjectOutput(BaseModel): class NeedsenumobjectErrorsFooAnyOf_0(BaseModel): - beep: Literal["err_first"] | None = None + beep: Literal["err_first"] | None = Field(default=None) class NeedsenumobjectErrorsFooAnyOf_1(BaseModel): - borp: Literal["err_second"] | None = None + borp: Literal["err_second"] | None = Field(default=None) NeedsenumobjectErrorsFoo = Annotated[ @@ -121,7 +121,7 @@ class NeedsenumobjectErrorsFooAnyOf_1(BaseModel): class NeedsenumobjectErrors(RiverError): - foo: NeedsenumobjectErrorsFoo | None = None + foo: NeedsenumobjectErrorsFoo | None = Field(default=None) NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( diff --git a/tests/v1/codegen/test_invalid_schema.py b/tests/v1/codegen/test_invalid_schema.py new file mode 100644 index 00000000..9a887e6e --- /dev/null +++ b/tests/v1/codegen/test_invalid_schema.py @@ -0,0 +1,27 @@ +from io import StringIO + +import pytest + +from replit_river.codegen.client import schema_to_river_client_codegen + + +def test_field_name_collision_error() -> None: + """Test that codegen raises ValueError for field name collisions.""" + + with pytest.raises(ValueError) as exc_info: + schema_to_river_client_codegen( + read_schema=lambda: open("tests/v1/codegen/rpc/invalid-schema.json"), + target_path="tests/v1/codegen/rpc/generated", + client_name="InvalidClient", + typed_dict_inputs=True, + file_opener=lambda _: StringIO(), + method_filter=None, + protocol_version="v1.1", + ) + + # Check that the error message matches the expected format for field name collision + error_message = str(exc_info.value) + assert "Field name collision" in error_message + assert "data:3" in error_message + assert "data-3" in error_message + assert "all normalize to the same effective name 'data_3'" in error_message diff --git a/tests/v1/codegen/test_rpc.py b/tests/v1/codegen/test_rpc.py index 55837190..e9cd0699 100644 --- a/tests/v1/codegen/test_rpc.py +++ b/tests/v1/codegen/test_rpc.py @@ -2,7 +2,7 @@ import importlib import os import shutil -from datetime import timedelta +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import TextIO @@ -52,6 +52,7 @@ async def test_basic_rpc(client: Client) -> None: res = await RpcClient(client).test_service.rpc_method( { "data": "feep", + "data2": datetime.now(timezone.utc), }, timedelta(seconds=5), ) @@ -80,8 +81,6 @@ async def test_rpc_timeout(client: Client) -> None: with pytest.raises(RiverException): await RpcClient(client).test_service.rpc_method( - { - "data": "feep", - }, + {"data": "feep", "data2": datetime.now(timezone.utc)}, timedelta(milliseconds=200), )