Skip to content

Commit 437fece

Browse files
committed
Fix optional and add error on field conflict
1 parent 0647cf7 commit 437fece

File tree

8 files changed

+137
-27
lines changed

8 files changed

+137
-27
lines changed

src/replit_river/codegen/client.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
33
import subprocess
4+
from collections import defaultdict
45
from pathlib import Path
56
from textwrap import dedent
67
from typing import (
@@ -529,6 +530,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
529530
# lambda x: ... vs lambda _: {}
530531
needs_binding = False
531532
encoder_names = set()
533+
# Track effective field names to detect collisions after normalization
534+
# Maps effective name -> list of original field names
535+
effective_field_names: defaultdict[str, list[str]] = defaultdict(list)
532536
if type.properties:
533537
needs_binding = True
534538
typeddict_encoder.append("{")
@@ -658,19 +662,37 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
658662
value = ""
659663
if base_model != "TypedDict":
660664
value = f"= {field_value}"
665+
# Track $kind -> "kind" mapping for collision detection
666+
effective_field_names["kind"].append(name)
667+
661668
current_chunks.append(
662669
f" kind: Annotated[{render_type_expr(type_name)}, Field(alias={
663670
repr(name)
664671
})]{value}"
665672
)
666673
else:
674+
specialized_name = normalize_special_chars(name)
675+
effective_name = name
676+
extras = []
677+
if name != specialized_name:
678+
if base_model != "BaseModel":
679+
# TODO: alias support for TypedDict
680+
raise ValueError(
681+
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
682+
)
683+
# Pydantic doesn't allow leading underscores in field names
684+
effective_name = specialized_name.lstrip("_")
685+
extras.append(f"alias={repr(name)}")
686+
687+
effective_field_names[effective_name].append(name)
688+
667689
if name not in type.required:
668690
if base_model == "TypedDict":
669691
current_chunks.append(
670692
reindent(
671693
" ",
672694
f"""\
673-
{name}: NotRequired[{
695+
{effective_name}: NotRequired[{
674696
render_type_expr(
675697
UnionTypeExpr([type_name, NoneTypeExpr()])
676698
)
@@ -679,11 +701,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679701
)
680702
)
681703
else:
704+
extras.append("default=None")
705+
682706
current_chunks.append(
683707
reindent(
684708
" ",
685709
f"""\
686-
{name}: {
710+
{effective_name}: {
687711
render_type_expr(
688712
UnionTypeExpr(
689713
[
@@ -692,28 +716,30 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
692716
]
693717
)
694718
)
695-
} = None
719+
} = Field({", ".join(extras)})
696720
""",
697721
)
698722
)
699723
else:
700-
specialized_name = normalize_special_chars(name)
701-
effective_name = name
702-
extras = ""
703-
if name != specialized_name:
704-
if base_model != "BaseModel":
705-
# TODO: alias support for TypedDict
706-
raise ValueError(
707-
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
708-
)
709-
# Pydantic doesn't allow leading underscores in field names
710-
effective_name = specialized_name.lstrip("_")
711-
extras = f" = Field(serialization_alias={repr(name)})"
724+
extras_str = ""
725+
if len(extras) != 0:
726+
extras_str = f" = Field({', '.join(extras)})"
712727

713728
current_chunks.append(
714-
f" {effective_name}: {render_type_expr(type_name)}{extras}"
729+
f" {effective_name}: {render_type_expr(type_name)}{extras_str}" # noqa: E501
715730
)
716731
typeddict_encoder.append(",")
732+
733+
# Check for field name collisions after processing all fields
734+
for effective_name, original_names in effective_field_names.items():
735+
if len(original_names) > 1:
736+
error_msg = (
737+
f"Field name collision: fields {original_names} all normalize "
738+
f"to the same effective name '{effective_name}'"
739+
)
740+
741+
raise ValueError(error_msg)
742+
717743
typeddict_encoder.append("}")
718744
# exclude_none
719745
typeddict_encoder = (

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
from typing import Any, Literal, Mapping
23

34
import nanoid
@@ -55,7 +56,11 @@ def deserialize_request(request: dict) -> str:
5556

5657

5758
def serialize_response(response: str) -> dict:
58-
return {"data": response}
59+
return {
60+
"data": response,
61+
"data2": datetime.now(timezone.utc),
62+
"data-3": {"data-test": "test"},
63+
}
5964

6065

6166
def deserialize_response(response: dict) -> str:

tests/v1/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def encode_Rpc_MethodInput(
3030
for (k, v) in (
3131
{
3232
"data": x.get("data"),
33+
"data2": x.get("data2"),
3334
}
3435
).items()
3536
if v is not None
@@ -38,10 +39,17 @@ def encode_Rpc_MethodInput(
3839

3940
class Rpc_MethodInput(TypedDict):
4041
data: str
42+
data2: datetime.datetime
43+
44+
45+
class Rpc_MethodOutputData_3(BaseModel):
46+
data_test: str | None = Field(alias="data-test", default=None)
4147

4248

4349
class Rpc_MethodOutput(BaseModel):
4450
data: str
51+
data_3: Rpc_MethodOutputData_3 = Field(alias="data-3")
52+
data2: datetime.datetime
4553

4654

4755
Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"rpc_method": {
6+
"input": {
7+
"type": "boolean"
8+
},
9+
"output": {
10+
"type": "object",
11+
"properties": {
12+
"data:3": {
13+
"type": "Date"
14+
},
15+
"data-3": {
16+
"type": "boolean"
17+
}
18+
},
19+
"required": ["data:3"]
20+
},
21+
"errors": {
22+
"not": {}
23+
},
24+
"type": "rpc"
25+
}
26+
}
27+
}
28+
}
29+
}
30+

tests/v1/codegen/rpc/schema.json

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,33 @@
88
"properties": {
99
"data": {
1010
"type": "string"
11+
},
12+
"data2": {
13+
"type": "Date"
1114
}
1215
},
13-
"required": ["data"]
16+
"required": ["data", "data2"]
1417
},
1518
"output": {
1619
"type": "object",
1720
"properties": {
1821
"data": {
1922
"type": "string"
23+
},
24+
"data2": {
25+
"type": "Date"
26+
},
27+
"data-3": {
28+
"type": "object",
29+
"properties": {
30+
"data-test": {
31+
"type": "string"
32+
}
33+
},
34+
"required": []
2035
}
2136
},
22-
"required": ["data"]
37+
"required": ["data", "data2", "data-3"]
2338
},
2439
"errors": {
2540
"not": {}

tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel):
9696

9797

9898
class NeedsenumobjectOutput(BaseModel):
99-
foo: NeedsenumobjectOutputFoo | None = None
99+
foo: NeedsenumobjectOutputFoo | None = Field(default=None)
100100

101101

102102
NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter(
@@ -105,11 +105,11 @@ class NeedsenumobjectOutput(BaseModel):
105105

106106

107107
class NeedsenumobjectErrorsFooAnyOf_0(BaseModel):
108-
beep: Literal["err_first"] | None = None
108+
beep: Literal["err_first"] | None = Field(default=None)
109109

110110

111111
class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):
112-
borp: Literal["err_second"] | None = None
112+
borp: Literal["err_second"] | None = Field(default=None)
113113

114114

115115
NeedsenumobjectErrorsFoo = Annotated[
@@ -121,7 +121,7 @@ class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):
121121

122122

123123
class NeedsenumobjectErrors(RiverError):
124-
foo: NeedsenumobjectErrorsFoo | None = None
124+
foo: NeedsenumobjectErrorsFoo | None = Field(default=None)
125125

126126

127127
NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter(
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from io import StringIO
2+
3+
import pytest
4+
5+
from replit_river.codegen.client import schema_to_river_client_codegen
6+
7+
8+
def test_field_name_collision_error() -> None:
9+
"""Test that codegen raises ValueError for field name collisions."""
10+
11+
with pytest.raises(ValueError) as exc_info:
12+
schema_to_river_client_codegen(
13+
read_schema=lambda: open("tests/v1/codegen/rpc/invalid-schema.json"),
14+
target_path="tests/v1/codegen/rpc/generated",
15+
client_name="InvalidClient",
16+
typed_dict_inputs=True,
17+
file_opener=lambda _: StringIO(),
18+
method_filter=None,
19+
protocol_version="v1.1",
20+
)
21+
22+
# Check that the error message matches the expected format for field name collision
23+
error_message = str(exc_info.value)
24+
assert "Field name collision" in error_message
25+
assert "data:3" in error_message
26+
assert "data-3" in error_message
27+
assert "all normalize to the same effective name 'data_3'" in error_message

tests/v1/codegen/test_rpc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import importlib
33
import os
44
import shutil
5-
from datetime import timedelta
5+
from datetime import datetime, timedelta, timezone
66
from pathlib import Path
77
from typing import TextIO
88

@@ -52,6 +52,7 @@ async def test_basic_rpc(client: Client) -> None:
5252
res = await RpcClient(client).test_service.rpc_method(
5353
{
5454
"data": "feep",
55+
"data2": datetime.now(timezone.utc),
5556
},
5657
timedelta(seconds=5),
5758
)
@@ -80,8 +81,6 @@ async def test_rpc_timeout(client: Client) -> None:
8081

8182
with pytest.raises(RiverException):
8283
await RpcClient(client).test_service.rpc_method(
83-
{
84-
"data": "feep",
85-
},
84+
{"data": "feep", "data2": datetime.now(timezone.utc)},
8685
timedelta(milliseconds=200),
8786
)

0 commit comments

Comments
 (0)