Skip to content

Commit a24ba26

Browse files
Missing some signatures from #111 (#112)
Why === I handled the output types, but not the input types. This is arguably more onerous for server implementations, but this still does conform to how grpc server codegen types work today. What changed ============ For lack of a stable way to say "I will be providing you an `AsyncIterator[RequestType]`, I don't care if you take other types as well" we should at the very least conform to the type signature exposed by grpc codegen. Test plan ========= Does bumping this library cause existing modules to typecheck correctly?
1 parent a65041a commit a24ba26

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

replit_river/rpc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Dict,
1010
Generic,
1111
Iterable,
12+
Iterator,
1213
Literal,
1314
Mapping,
1415
NoReturn,
@@ -311,7 +312,7 @@ async def wrapped(
311312

312313
def upload_method_handler(
313314
method: Callable[
314-
[AsyncIterator[RequestType], grpc.aio.ServicerContext],
315+
[Iterator[RequestType] | AsyncIterator[RequestType], grpc.aio.ServicerContext],
315316
ResponseType | Awaitable[ResponseType],
316317
],
317318
request_deserializer: Callable[[Any], RequestType],
@@ -388,7 +389,7 @@ async def _convert_outputs() -> None:
388389

389390
def stream_method_handler(
390391
method: Callable[
391-
[AsyncIterator[RequestType], grpc.aio.ServicerContext],
392+
[Iterator[RequestType] | AsyncIterator[RequestType], grpc.aio.ServicerContext],
392393
AsyncIterable[ResponseType],
393394
],
394395
request_deserializer: Callable[[Any], RequestType],

tests/conftest.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import AsyncIterator
4-
from typing import Any, AsyncGenerator, Literal
4+
from typing import Any, AsyncGenerator, Iterator, Literal
55

66
import nanoid # type: ignore
77
import pytest
@@ -79,18 +79,28 @@ async def subscription_handler(
7979
yield f"Subscription message {i} for {request}"
8080

8181

82-
async def upload_handler(request: AsyncIterator[str], context: Any) -> str:
82+
async def upload_handler(
83+
request: Iterator[str] | AsyncIterator[str], context: Any
84+
) -> str:
8385
uploaded_data = []
84-
async for data in request:
85-
uploaded_data.append(data)
86+
if isinstance(request, AsyncIterator):
87+
async for data in request:
88+
uploaded_data.append(data)
89+
else:
90+
for data in request:
91+
uploaded_data.append(data)
8692
return f"Uploaded: {', '.join(uploaded_data)}"
8793

8894

8995
async def stream_handler(
90-
request: AsyncIterator[str], context: GrpcContext
96+
request: Iterator[str] | AsyncIterator[str], context: GrpcContext
9197
) -> AsyncGenerator[str, None]:
92-
async for data in request:
93-
yield f"Stream response for {data}"
98+
if isinstance(request, AsyncIterator):
99+
async for data in request:
100+
yield f"Stream response for {data}"
101+
else:
102+
for data in request:
103+
yield f"Stream response for {data}"
94104

95105

96106
@pytest.fixture

0 commit comments

Comments
 (0)