Skip to content

Commit 10be03c

Browse files
committed
another e2e test
1 parent a1b0ded commit 10be03c

File tree

1 file changed

+97
-1
lines changed

1 file changed

+97
-1
lines changed

tests/v1/codegen/test_rpc.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,31 @@ def file_opener(path: Path) -> TextIO:
3838

3939

4040
@pytest.fixture(scope="session", autouse=True)
41-
def reload_rpc_import(generate_rpc_client: None) -> None:
41+
def generate_special_chars_client() -> None:
42+
shutil.rmtree("tests/v1/codegen/rpc/generated_special_chars", ignore_errors=True)
43+
os.makedirs("tests/v1/codegen/rpc/generated_special_chars")
44+
45+
def file_opener(path: Path) -> TextIO:
46+
return open(path, "w")
47+
48+
schema_to_river_client_codegen(
49+
read_schema=lambda: open("tests/v1/codegen/rpc/input-special-chars-schema.json"), # noqa: E501
50+
target_path="tests/v1/codegen/rpc/generated_special_chars",
51+
client_name="SpecialCharsClient",
52+
typed_dict_inputs=True,
53+
file_opener=file_opener,
54+
method_filter=None,
55+
protocol_version="v1.1",
56+
)
57+
58+
59+
@pytest.fixture(scope="session", autouse=True)
60+
def reload_rpc_import(generate_rpc_client: None, generate_special_chars_client: None) -> None: # noqa: E501
4261
import tests.v1.codegen.rpc.generated
62+
import tests.v1.codegen.rpc.generated_special_chars
4363

4464
importlib.reload(tests.v1.codegen.rpc.generated)
65+
importlib.reload(tests.v1.codegen.rpc.generated_special_chars)
4566

4667

4768
@pytest.mark.asyncio
@@ -74,6 +95,50 @@ async def rpc_timeout_handler(request: str, context: grpc.aio.ServicerContext) -
7495
}
7596

7697

98+
def deserialize_special_chars_request(request: dict) -> dict:
99+
"""Deserialize request for special chars test - pass through the full dict."""
100+
return request
101+
102+
103+
def serialize_special_chars_response(response: bool) -> dict:
104+
"""Serialize response for special chars test."""
105+
return response
106+
107+
108+
async def special_chars_handler(request: dict, context: grpc.aio.ServicerContext) -> bool: # noqa: E501
109+
"""Handler that processes input with special character field names."""
110+
# The request comes with original field names (with special characters)
111+
# as they are sent over the wire before normalization
112+
113+
# Verify we received the required fields with their original names
114+
required_fields = ["data-field1", "data:field2"]
115+
116+
for field in required_fields:
117+
if field not in request:
118+
raise ValueError(f"Missing required field: {field}. Available keys: {list(request.keys())}") # noqa: E501
119+
120+
# Verify the values are of expected types
121+
if not isinstance(request["data-field1"], str):
122+
raise ValueError("data-field1 should be a string")
123+
if not isinstance(request["data:field2"], (int, float)):
124+
raise ValueError("data:field2 should be a number")
125+
126+
# Return True if all expected data is present and valid
127+
return True
128+
129+
130+
special_chars_method: HandlerMapping = {
131+
("test_service", "rpc_method"): (
132+
"rpc",
133+
rpc_method_handler(
134+
special_chars_handler,
135+
deserialize_special_chars_request,
136+
serialize_special_chars_response,
137+
),
138+
)
139+
}
140+
141+
77142
@pytest.mark.asyncio
78143
@pytest.mark.parametrize("handlers", [{**rpc_timeout_method}])
79144
async def test_rpc_timeout(client: Client) -> None:
@@ -84,3 +149,34 @@ async def test_rpc_timeout(client: Client) -> None:
84149
{"data": "feep", "data2": datetime.now(timezone.utc)},
85150
timedelta(milliseconds=200),
86151
)
152+
153+
154+
@pytest.mark.asyncio
155+
@pytest.mark.parametrize("handlers", [{**special_chars_method}])
156+
async def test_special_chars_rpc(client: Client) -> None:
157+
"""Test RPC method with special characters in field names."""
158+
from tests.v1.codegen.rpc.generated_special_chars import SpecialCharsClient
159+
160+
# Test with all fields including optional ones
161+
result = await SpecialCharsClient(client).test_service.rpc_method(
162+
{
163+
"data_field1": "test_value1", # Required: data-field1 -> data_field1
164+
"data_field2": 42.5, # Required: data:field2 -> data_field2
165+
"data_field3": True, # Optional: data.field3 -> data_field3
166+
"data_field4": "test_value4", # Optional: data/field4 -> data_field4
167+
"data_field5": 123, # Optional: data@field5 -> data_field5
168+
"data_field6": "test_value6", # Optional: data field6 -> data_field6
169+
},
170+
timedelta(seconds=5),
171+
)
172+
assert result is True
173+
174+
# Test with only required fields
175+
result = await SpecialCharsClient(client).test_service.rpc_method(
176+
{
177+
"data_field1": "required_value",
178+
"data_field2": 99.9,
179+
},
180+
timedelta(seconds=5),
181+
)
182+
assert result is True

0 commit comments

Comments
 (0)