Skip to content

Commit a65041a

Browse files
feat/maybe async handlers (#111)
Why === There are a bunch of needless LSP errors because we do `rpc_method_handler(servicer.Foo, ...)`, where `Foo` returns `ResponseType | Awaitable[ResponseType]`. These types are a little convoluted, but it makes it possible to use grpc-generated "servicer" as-is. What changed ============ Add synchronous response types for river server handlers Test plan ========= I manually patched these methods in production codebases and everything typechecks.
1 parent d4a1727 commit a65041a

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

replit_river/rpc.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def get_response_or_error_payload(
193193

194194

195195
def rpc_method_handler(
196-
method: Callable[[RequestType, grpc.aio.ServicerContext], Awaitable[ResponseType]],
196+
method: Callable[
197+
[RequestType, grpc.aio.ServicerContext], ResponseType | Awaitable[ResponseType]
198+
],
197199
request_deserializer: Callable[[Any], RequestType],
198200
response_serializer: Callable[[ResponseType], Any],
199201
) -> GenericRpcHandler:
@@ -206,7 +208,9 @@ async def wrapped(
206208
try:
207209
context = GrpcContext(peer)
208210
request = request_deserializer(await input.get())
209-
response = await method(request, context)
211+
response = method(request, context)
212+
if isinstance(response, Awaitable):
213+
response = await response
210214
await output.put(
211215
get_response_or_error_payload(response, response_serializer)
212216
)
@@ -247,7 +251,8 @@ async def wrapped(
247251

248252
def subscription_method_handler(
249253
method: Callable[
250-
[RequestType, grpc.aio.ServicerContext], AsyncIterable[ResponseType]
254+
[RequestType, grpc.aio.ServicerContext],
255+
Iterable[ResponseType] | AsyncIterable[ResponseType],
251256
],
252257
request_deserializer: Callable[[Any], RequestType],
253258
response_serializer: Callable[[ResponseType], Any],
@@ -261,10 +266,17 @@ async def wrapped(
261266
try:
262267
context = GrpcContext(peer)
263268
request = request_deserializer(await input.get())
264-
async for response in method(request, context):
265-
await output.put(
266-
get_response_or_error_payload(response, response_serializer)
267-
)
269+
iterator = method(request, context)
270+
if isinstance(iterator, AsyncIterable):
271+
async for response in iterator:
272+
await output.put(
273+
get_response_or_error_payload(response, response_serializer)
274+
)
275+
else:
276+
for response in iterator:
277+
await output.put(
278+
get_response_or_error_payload(response, response_serializer)
279+
)
268280
except grpc.RpcError:
269281
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
270282
message = (
@@ -300,7 +312,7 @@ async def wrapped(
300312
def upload_method_handler(
301313
method: Callable[
302314
[AsyncIterator[RequestType], grpc.aio.ServicerContext],
303-
Awaitable[ResponseType],
315+
ResponseType | Awaitable[ResponseType],
304316
],
305317
request_deserializer: Callable[[Any], RequestType],
306318
response_serializer: Callable[[ResponseType], Any],
@@ -324,7 +336,9 @@ async def _convert_inputs() -> None:
324336

325337
async def _convert_outputs() -> None:
326338
try:
327-
response = await method(request, context)
339+
response = method(request, context)
340+
if isinstance(response, Awaitable):
341+
response = await response
328342
await output.put(
329343
get_response_or_error_payload(response, response_serializer)
330344
)

0 commit comments

Comments
 (0)