Skip to content

Commit 6edec21

Browse files
committed
encode special fields
1 parent ea06e01 commit 6edec21

File tree

4 files changed

+248
-9
lines changed

4 files changed

+248
-9
lines changed

src/replit_river/codegen/client.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
575575
if name == "$kind":
576576
safe_name = "kind"
577577
else:
578-
safe_name = name
578+
# For TypedDict encoder, use normalized name to access the TypedDict field
579+
# but the output dictionary key should use the original name
580+
if base_model == "TypedDict":
581+
specialized_name = normalize_special_chars(name)
582+
safe_name = specialized_name.lstrip("_") if name != specialized_name else name
583+
else:
584+
safe_name = name
579585
if prop.type == "object" and not prop.patternProperties:
580586
encoder_name = TypeName(
581587
f"encode_{render_literal_type(type_name)}"
@@ -675,14 +681,18 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
675681
effective_name = name
676682
extras = []
677683
if name != specialized_name:
678-
if base_model != "BaseModel":
679-
# TODO: alias support for TypedDict
680-
raise ValueError(
681-
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
682-
)
683-
# Pydantic doesn't allow leading underscores in field names
684-
effective_name = specialized_name.lstrip("_")
685-
extras.append(f"alias={repr(name)}")
684+
if base_model == "BaseModel":
685+
# Pydantic doesn't allow leading underscores in field names
686+
effective_name = specialized_name.lstrip("_")
687+
extras.append(f"alias={repr(name)}")
688+
elif base_model == "TypedDict":
689+
# For TypedDict, we use the normalized name directly
690+
# TypedDict doesn't support aliases, so we normalize the field name
691+
effective_name = specialized_name.lstrip("_")
692+
else:
693+
# For RiverError (which extends BaseModel), use alias like BaseModel
694+
effective_name = specialized_name.lstrip("_")
695+
extras.append(f"alias={repr(name)}")
686696

687697
effective_field_names[effective_name].append(name)
688698

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"rpc_method": {
6+
"input": {
7+
"type": "object",
8+
"properties": {
9+
"data-3": {
10+
"type": "string"
11+
},
12+
"data:3": {
13+
"type": "number"
14+
}
15+
},
16+
"required": ["data-3", "data:3"]
17+
},
18+
"output": {
19+
"type": "boolean"
20+
},
21+
"errors": {
22+
"not": {}
23+
},
24+
"type": "rpc"
25+
}
26+
}
27+
}
28+
}
29+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"rpc_method": {
6+
"input": {
7+
"type": "object",
8+
"properties": {
9+
"data-field1": {
10+
"type": "string"
11+
},
12+
"data:field2": {
13+
"type": "number"
14+
},
15+
"data.field3": {
16+
"type": "boolean"
17+
},
18+
"data/field4": {
19+
"type": "string"
20+
},
21+
"data@field5": {
22+
"type": "integer"
23+
},
24+
"data field6": {
25+
"type": "string"
26+
}
27+
},
28+
"required": ["data-field1", "data:field2"]
29+
},
30+
"output": {
31+
"type": "boolean"
32+
},
33+
"errors": {
34+
"not": {}
35+
},
36+
"type": "rpc"
37+
}
38+
}
39+
}
40+
}
41+
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from io import StringIO
2+
3+
import pytest
4+
5+
from replit_river.codegen.client import schema_to_river_client_codegen
6+
7+
8+
def test_input_special_chars_basemodel() -> None:
9+
"""Test that codegen handles special characters in input field names for BaseModel."""
10+
11+
# Test should pass without raising an exception
12+
schema_to_river_client_codegen(
13+
read_schema=lambda: open("tests/v1/codegen/rpc/input-special-chars-schema.json"),
14+
target_path="tests/v1/codegen/rpc/generated_input_special",
15+
client_name="InputSpecialClient",
16+
typed_dict_inputs=False, # BaseModel inputs
17+
file_opener=lambda _: StringIO(),
18+
method_filter=None,
19+
protocol_version="v1.1",
20+
)
21+
22+
23+
def test_input_special_chars_typeddict() -> None:
24+
"""Test that codegen handles special characters in input field names for TypedDict."""
25+
26+
# Test should pass without raising an exception
27+
schema_to_river_client_codegen(
28+
read_schema=lambda: open("tests/v1/codegen/rpc/input-special-chars-schema.json"),
29+
target_path="tests/v1/codegen/rpc/generated_input_special_td",
30+
client_name="InputSpecialTDClient",
31+
typed_dict_inputs=True, # TypedDict inputs
32+
file_opener=lambda _: StringIO(),
33+
method_filter=None,
34+
protocol_version="v1.1",
35+
)
36+
37+
38+
def test_input_collision_error_basemodel() -> None:
39+
"""Test that codegen raises ValueError for input field name collisions with BaseModel."""
40+
41+
with pytest.raises(ValueError) as exc_info:
42+
schema_to_river_client_codegen(
43+
read_schema=lambda: open("tests/v1/codegen/rpc/input-collision-schema.json"),
44+
target_path="tests/v1/codegen/rpc/generated_input_collision",
45+
client_name="InputCollisionClient",
46+
typed_dict_inputs=False, # BaseModel inputs
47+
file_opener=lambda _: StringIO(),
48+
method_filter=None,
49+
protocol_version="v1.1",
50+
)
51+
52+
# Check that the error message matches the expected format for field name collision
53+
error_message = str(exc_info.value)
54+
assert "Field name collision" in error_message
55+
assert "data-3" in error_message
56+
assert "data:3" in error_message
57+
assert "all normalize to the same effective name 'data_3'" in error_message
58+
59+
60+
def test_input_collision_error_typeddict() -> None:
61+
"""Test that codegen raises ValueError for input field name collisions with TypedDict."""
62+
63+
with pytest.raises(ValueError) as exc_info:
64+
schema_to_river_client_codegen(
65+
read_schema=lambda: open("tests/v1/codegen/rpc/input-collision-schema.json"),
66+
target_path="tests/v1/codegen/rpc/generated_input_collision_td",
67+
client_name="InputCollisionTDClient",
68+
typed_dict_inputs=True, # TypedDict inputs
69+
file_opener=lambda _: StringIO(),
70+
method_filter=None,
71+
protocol_version="v1.1",
72+
)
73+
74+
# Check that the error message matches the expected format for field name collision
75+
error_message = str(exc_info.value)
76+
assert "Field name collision" in error_message
77+
assert "data-3" in error_message
78+
assert "data:3" in error_message
79+
assert "all normalize to the same effective name 'data_3'" in error_message
80+
81+
82+
def test_init_special_chars_basemodel() -> None:
83+
"""Test that codegen handles special characters in init field names for BaseModel."""
84+
85+
init_schema = {
86+
"services": {
87+
"test_service": {
88+
"procedures": {
89+
"stream_method": {
90+
"init": {
91+
"type": "object",
92+
"properties": {
93+
"init-field1": {"type": "string"},
94+
"init:field2": {"type": "number"},
95+
"init.field3": {"type": "boolean"}
96+
},
97+
"required": ["init-field1"]
98+
},
99+
"output": {"type": "boolean"},
100+
"errors": {"not": {}},
101+
"type": "stream"
102+
}
103+
}
104+
}
105+
}
106+
}
107+
108+
import json
109+
110+
# Test should pass without raising an exception
111+
schema_to_river_client_codegen(
112+
read_schema=lambda: StringIO(json.dumps(init_schema)),
113+
target_path="tests/v1/codegen/rpc/generated_init_special",
114+
client_name="InitSpecialClient",
115+
typed_dict_inputs=False, # BaseModel inputs
116+
file_opener=lambda _: StringIO(),
117+
method_filter=None,
118+
protocol_version="v2.0",
119+
)
120+
121+
122+
def test_init_special_chars_typeddict() -> None:
123+
"""Test that codegen handles special characters in init field names for TypedDict."""
124+
125+
init_schema = {
126+
"services": {
127+
"test_service": {
128+
"procedures": {
129+
"stream_method": {
130+
"init": {
131+
"type": "object",
132+
"properties": {
133+
"init-field1": {"type": "string"},
134+
"init:field2": {"type": "number"},
135+
"init.field3": {"type": "boolean"}
136+
},
137+
"required": ["init-field1"]
138+
},
139+
"output": {"type": "boolean"},
140+
"errors": {"not": {}},
141+
"type": "stream"
142+
}
143+
}
144+
}
145+
}
146+
}
147+
148+
import json
149+
150+
# Test should pass without raising an exception
151+
schema_to_river_client_codegen(
152+
read_schema=lambda: StringIO(json.dumps(init_schema)),
153+
target_path="tests/v1/codegen/rpc/generated_init_special_td",
154+
client_name="InitSpecialTDClient",
155+
typed_dict_inputs=True, # TypedDict inputs
156+
file_opener=lambda _: StringIO(),
157+
method_filter=None,
158+
protocol_version="v2.0",
159+
)

0 commit comments

Comments
 (0)