@@ -38,10 +38,31 @@ def file_opener(path: Path) -> TextIO:
3838
3939
4040@pytest .fixture (scope = "session" , autouse = True )
41- def reload_rpc_import (generate_rpc_client : None ) -> None :
41+ def generate_special_chars_client () -> None :
42+ shutil .rmtree ("tests/v1/codegen/rpc/generated_special_chars" , ignore_errors = True )
43+ os .makedirs ("tests/v1/codegen/rpc/generated_special_chars" )
44+
45+ def file_opener (path : Path ) -> TextIO :
46+ return open (path , "w" )
47+
48+ schema_to_river_client_codegen (
49+ read_schema = lambda : open ("tests/v1/codegen/rpc/input-special-chars-schema.json" ), # noqa: E501
50+ target_path = "tests/v1/codegen/rpc/generated_special_chars" ,
51+ client_name = "SpecialCharsClient" ,
52+ typed_dict_inputs = True ,
53+ file_opener = file_opener ,
54+ method_filter = None ,
55+ protocol_version = "v1.1" ,
56+ )
57+
58+
59+ @pytest .fixture (scope = "session" , autouse = True )
60+ def reload_rpc_import (generate_rpc_client : None , generate_special_chars_client : None ) -> None : # noqa: E501
4261 import tests .v1 .codegen .rpc .generated
62+ import tests .v1 .codegen .rpc .generated_special_chars
4363
4464 importlib .reload (tests .v1 .codegen .rpc .generated )
65+ importlib .reload (tests .v1 .codegen .rpc .generated_special_chars )
4566
4667
4768@pytest .mark .asyncio
@@ -74,6 +95,50 @@ async def rpc_timeout_handler(request: str, context: grpc.aio.ServicerContext) -
7495}
7596
7697
98+ def deserialize_special_chars_request (request : dict ) -> dict :
99+ """Deserialize request for special chars test - pass through the full dict."""
100+ return request
101+
102+
103+ def serialize_special_chars_response (response : bool ) -> dict :
104+ """Serialize response for special chars test."""
105+ return response
106+
107+
108+ async def special_chars_handler (request : dict , context : grpc .aio .ServicerContext ) -> bool : # noqa: E501
109+ """Handler that processes input with special character field names."""
110+ # The request comes with original field names (with special characters)
111+ # as they are sent over the wire before normalization
112+
113+ # Verify we received the required fields with their original names
114+ required_fields = ["data-field1" , "data:field2" ]
115+
116+ for field in required_fields :
117+ if field not in request :
118+ raise ValueError (f"Missing required field: { field } . Available keys: { list (request .keys ())} " ) # noqa: E501
119+
120+ # Verify the values are of expected types
121+ if not isinstance (request ["data-field1" ], str ):
122+ raise ValueError ("data-field1 should be a string" )
123+ if not isinstance (request ["data:field2" ], (int , float )):
124+ raise ValueError ("data:field2 should be a number" )
125+
126+ # Return True if all expected data is present and valid
127+ return True
128+
129+
130+ special_chars_method : HandlerMapping = {
131+ ("test_service" , "rpc_method" ): (
132+ "rpc" ,
133+ rpc_method_handler (
134+ special_chars_handler ,
135+ deserialize_special_chars_request ,
136+ serialize_special_chars_response ,
137+ ),
138+ )
139+ }
140+
141+
77142@pytest .mark .asyncio
78143@pytest .mark .parametrize ("handlers" , [{** rpc_timeout_method }])
79144async def test_rpc_timeout (client : Client ) -> None :
@@ -84,3 +149,34 @@ async def test_rpc_timeout(client: Client) -> None:
84149 {"data" : "feep" , "data2" : datetime .now (timezone .utc )},
85150 timedelta (milliseconds = 200 ),
86151 )
152+
153+
154+ @pytest .mark .asyncio
155+ @pytest .mark .parametrize ("handlers" , [{** special_chars_method }])
156+ async def test_special_chars_rpc (client : Client ) -> None :
157+ """Test RPC method with special characters in field names."""
158+ from tests .v1 .codegen .rpc .generated_special_chars import SpecialCharsClient
159+
160+ # Test with all fields including optional ones
161+ result = await SpecialCharsClient (client ).test_service .rpc_method (
162+ {
163+ "data_field1" : "test_value1" , # Required: data-field1 -> data_field1
164+ "data_field2" : 42.5 , # Required: data:field2 -> data_field2
165+ "data_field3" : True , # Optional: data.field3 -> data_field3
166+ "data_field4" : "test_value4" , # Optional: data/field4 -> data_field4
167+ "data_field5" : 123 , # Optional: data@field5 -> data_field5
168+ "data_field6" : "test_value6" , # Optional: data field6 -> data_field6
169+ },
170+ timedelta (seconds = 5 ),
171+ )
172+ assert result is True
173+
174+ # Test with only required fields
175+ result = await SpecialCharsClient (client ).test_service .rpc_method (
176+ {
177+ "data_field1" : "required_value" ,
178+ "data_field2" : 99.9 ,
179+ },
180+ timedelta (seconds = 5 ),
181+ )
182+ assert result is True
0 commit comments