Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
if name == "$kind":
safe_name = "kind"
else:
safe_name = name
# For TypedDict encoder, use normalized name to access the TypedDict field
# but the output dictionary key should use the original name
if base_model == "TypedDict":
specialized_name = normalize_special_chars(name)
safe_name = specialized_name.lstrip("_") if name != specialized_name else name
else:
safe_name = name
if prop.type == "object" and not prop.patternProperties:
encoder_name = TypeName(
f"encode_{render_literal_type(type_name)}"
Expand Down Expand Up @@ -675,14 +681,18 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
effective_name = name
extras = []
if name != specialized_name:
if base_model != "BaseModel":
# TODO: alias support for TypedDict
raise ValueError(
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
)
# Pydantic doesn't allow leading underscores in field names
effective_name = specialized_name.lstrip("_")
extras.append(f"alias={repr(name)}")
if base_model == "BaseModel":
# Pydantic doesn't allow leading underscores in field names
effective_name = specialized_name.lstrip("_")
extras.append(f"alias={repr(name)}")
elif base_model == "TypedDict":
# For TypedDict, we use the normalized name directly
# TypedDict doesn't support aliases, so we normalize the field name
effective_name = specialized_name.lstrip("_")
else:
# For RiverError (which extends BaseModel), use alias like BaseModel
effective_name = specialized_name.lstrip("_")
extras.append(f"alias={repr(name)}")

effective_field_names[effective_name].append(name)

Expand Down
29 changes: 29 additions & 0 deletions tests/v1/codegen/rpc/input-collision-schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"services": {
"test_service": {
"procedures": {
"rpc_method": {
"input": {
"type": "object",
"properties": {
"data-3": {
"type": "string"
},
"data:3": {
"type": "number"
}
},
"required": ["data-3", "data:3"]
},
"output": {
"type": "boolean"
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}
41 changes: 41 additions & 0 deletions tests/v1/codegen/rpc/input-special-chars-schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"services": {
"test_service": {
"procedures": {
"rpc_method": {
"input": {
"type": "object",
"properties": {
"data-field1": {
"type": "string"
},
"data:field2": {
"type": "number"
},
"data.field3": {
"type": "boolean"
},
"data/field4": {
"type": "string"
},
"data@field5": {
"type": "integer"
},
"data field6": {
"type": "string"
}
},
"required": ["data-field1", "data:field2"]
},
"output": {
"type": "boolean"
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}
159 changes: 159 additions & 0 deletions tests/v1/codegen/test_input_special_chars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from io import StringIO

import pytest

from replit_river.codegen.client import schema_to_river_client_codegen


def test_input_special_chars_basemodel() -> None:
"""Test that codegen handles special characters in input field names for BaseModel."""

# Test should pass without raising an exception
schema_to_river_client_codegen(
read_schema=lambda: open("tests/v1/codegen/rpc/input-special-chars-schema.json"),
target_path="tests/v1/codegen/rpc/generated_input_special",
client_name="InputSpecialClient",
typed_dict_inputs=False, # BaseModel inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v1.1",
)


def test_input_special_chars_typeddict() -> None:
"""Test that codegen handles special characters in input field names for TypedDict."""

# Test should pass without raising an exception
schema_to_river_client_codegen(
read_schema=lambda: open("tests/v1/codegen/rpc/input-special-chars-schema.json"),
target_path="tests/v1/codegen/rpc/generated_input_special_td",
client_name="InputSpecialTDClient",
typed_dict_inputs=True, # TypedDict inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v1.1",
)


def test_input_collision_error_basemodel() -> None:
"""Test that codegen raises ValueError for input field name collisions with BaseModel."""

with pytest.raises(ValueError) as exc_info:
schema_to_river_client_codegen(
read_schema=lambda: open("tests/v1/codegen/rpc/input-collision-schema.json"),
target_path="tests/v1/codegen/rpc/generated_input_collision",
client_name="InputCollisionClient",
typed_dict_inputs=False, # BaseModel inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v1.1",
)

# Check that the error message matches the expected format for field name collision
error_message = str(exc_info.value)
assert "Field name collision" in error_message
assert "data-3" in error_message
assert "data:3" in error_message
assert "all normalize to the same effective name 'data_3'" in error_message


def test_input_collision_error_typeddict() -> None:
"""Test that codegen raises ValueError for input field name collisions with TypedDict."""

with pytest.raises(ValueError) as exc_info:
schema_to_river_client_codegen(
read_schema=lambda: open("tests/v1/codegen/rpc/input-collision-schema.json"),
target_path="tests/v1/codegen/rpc/generated_input_collision_td",
client_name="InputCollisionTDClient",
typed_dict_inputs=True, # TypedDict inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v1.1",
)

# Check that the error message matches the expected format for field name collision
error_message = str(exc_info.value)
assert "Field name collision" in error_message
assert "data-3" in error_message
assert "data:3" in error_message
assert "all normalize to the same effective name 'data_3'" in error_message


def test_init_special_chars_basemodel() -> None:
"""Test that codegen handles special characters in init field names for BaseModel."""

init_schema = {
"services": {
"test_service": {
"procedures": {
"stream_method": {
"init": {
"type": "object",
"properties": {
"init-field1": {"type": "string"},
"init:field2": {"type": "number"},
"init.field3": {"type": "boolean"}
},
"required": ["init-field1"]
},
"output": {"type": "boolean"},
"errors": {"not": {}},
"type": "stream"
}
}
}
}
}

import json

# Test should pass without raising an exception
schema_to_river_client_codegen(
read_schema=lambda: StringIO(json.dumps(init_schema)),
target_path="tests/v1/codegen/rpc/generated_init_special",
client_name="InitSpecialClient",
typed_dict_inputs=False, # BaseModel inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v2.0",
)


def test_init_special_chars_typeddict() -> None:
"""Test that codegen handles special characters in init field names for TypedDict."""

init_schema = {
"services": {
"test_service": {
"procedures": {
"stream_method": {
"init": {
"type": "object",
"properties": {
"init-field1": {"type": "string"},
"init:field2": {"type": "number"},
"init.field3": {"type": "boolean"}
},
"required": ["init-field1"]
},
"output": {"type": "boolean"},
"errors": {"not": {}},
"type": "stream"
}
}
}
}
}

import json

# Test should pass without raising an exception
schema_to_river_client_codegen(
read_schema=lambda: StringIO(json.dumps(init_schema)),
target_path="tests/v1/codegen/rpc/generated_init_special_td",
client_name="InitSpecialTDClient",
typed_dict_inputs=True, # TypedDict inputs
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v2.0",
)
Loading