Skip to content

Commit 502278f

Browse files
[tests] handler decomposition (#125)
Why === Striking a balance between test handler reuse and colocating the handler specifiers with the tests that use them. What changed ============ - Decomposing monolithic server handler specifiers into directly DI-ing reusable components - For tests that need bespoke handlers, defining those alongside the methods that use them. Test plan ========= CI
1 parent 88f4347 commit 502278f

File tree

7 files changed

+171
-93
lines changed

7 files changed

+171
-93
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dev-dependencies = [
4646
"pytest-mock>=3.11.1",
4747
"ruff>=0.0.278",
4848
"types-protobuf>=4.24.0.20240311",
49+
"types-nanoid>=2.0.0.20240601",
4950
]
5051

5152
[tool.ruff]

tests/common_handlers.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Any, AsyncGenerator, AsyncIterator, Iterator
2+
3+
import grpc
4+
import grpc.aio
5+
6+
from replit_river.rpc import (
7+
rpc_method_handler,
8+
stream_method_handler,
9+
subscription_method_handler,
10+
upload_method_handler,
11+
)
12+
from tests.conftest import HandlerMapping, deserialize_request, serialize_response
13+
14+
15+
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
16+
return f"Hello, {request}!"
17+
18+
19+
basic_rpc_method: HandlerMapping = {
20+
("test_service", "rpc_method"): (
21+
"rpc",
22+
rpc_method_handler(rpc_handler, deserialize_request, serialize_response),
23+
)
24+
}
25+
26+
27+
async def upload_handler(
28+
request: Iterator[str] | AsyncIterator[str], context: Any
29+
) -> str:
30+
uploaded_data = []
31+
if isinstance(request, AsyncIterator):
32+
async for data in request:
33+
uploaded_data.append(data)
34+
else:
35+
for data in request:
36+
uploaded_data.append(data)
37+
return f"Uploaded: {', '.join(uploaded_data)}"
38+
39+
40+
basic_upload: HandlerMapping = {
41+
("test_service", "upload_method"): (
42+
"upload",
43+
upload_method_handler(upload_handler, deserialize_request, serialize_response),
44+
),
45+
}
46+
47+
48+
async def subscription_handler(
49+
request: str, context: grpc.aio.ServicerContext
50+
) -> AsyncGenerator[str, None]:
51+
for i in range(5):
52+
yield f"Subscription message {i} for {request}"
53+
54+
55+
basic_subscription: HandlerMapping = {
56+
("test_service", "subscription_method"): (
57+
"subscription",
58+
subscription_method_handler(
59+
subscription_handler, deserialize_request, serialize_response
60+
),
61+
),
62+
}
63+
64+
65+
async def stream_handler(
66+
request: Iterator[str] | AsyncIterator[str],
67+
context: grpc.aio.ServicerContext,
68+
) -> AsyncGenerator[str, None]:
69+
if isinstance(request, AsyncIterator):
70+
async for data in request:
71+
yield f"Stream response for {data}"
72+
else:
73+
for data in request:
74+
yield f"Stream response for {data}"
75+
76+
77+
basic_stream: HandlerMapping = {
78+
("test_service", "stream_method"): (
79+
"stream",
80+
stream_method_handler(stream_handler, deserialize_request, serialize_response),
81+
),
82+
}

tests/conftest.py

Lines changed: 15 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import asyncio
22
import logging
3-
from collections.abc import AsyncIterator
4-
from typing import Any, AsyncGenerator, Iterator, Literal
3+
from typing import Any, AsyncGenerator, Literal, Mapping
54

6-
import grpc.aio
7-
import nanoid # type: ignore
5+
import nanoid
86
import pytest
97
from opentelemetry import trace
108
from opentelemetry.sdk.trace import TracerProvider
@@ -14,13 +12,10 @@
1412

1513
from replit_river.client import Client
1614
from replit_river.client_transport import UriAndMetadata
17-
from replit_river.error_schema import RiverError, RiverException
15+
from replit_river.error_schema import RiverError
1816
from replit_river.rpc import (
17+
GenericRpcHandler,
1918
TransportMessage,
20-
rpc_method_handler,
21-
stream_method_handler,
22-
subscription_method_handler,
23-
upload_method_handler,
2419
)
2520
from replit_river.server import Server
2621
from replit_river.transport_options import TransportOptions
@@ -29,6 +24,8 @@
2924
# Modular fixtures
3025
pytest_plugins = ["tests.river_fixtures.logging"]
3126

27+
HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]
28+
3229

3330
def transport_message(
3431
seq: int = 0,
@@ -71,93 +68,22 @@ def deserialize_error(response: dict) -> RiverError:
7168
return RiverError.model_validate(response)
7269

7370

74-
# RPC method handlers for testing
75-
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
76-
return f"Hello, {request}!"
77-
78-
79-
async def subscription_handler(
80-
request: str, context: grpc.aio.ServicerContext
81-
) -> AsyncGenerator[str, None]:
82-
for i in range(5):
83-
yield f"Subscription message {i} for {request}"
84-
85-
86-
async def upload_handler(
87-
request: Iterator[str] | AsyncIterator[str], context: Any
88-
) -> str:
89-
uploaded_data = []
90-
if isinstance(request, AsyncIterator):
91-
async for data in request:
92-
uploaded_data.append(data)
93-
else:
94-
for data in request:
95-
uploaded_data.append(data)
96-
return f"Uploaded: {', '.join(uploaded_data)}"
97-
98-
99-
async def stream_handler(
100-
request: Iterator[str] | AsyncIterator[str],
101-
context: grpc.aio.ServicerContext,
102-
) -> AsyncGenerator[str, None]:
103-
if isinstance(request, AsyncIterator):
104-
async for data in request:
105-
yield f"Stream response for {data}"
106-
else:
107-
for data in request:
108-
yield f"Stream response for {data}"
109-
110-
111-
async def stream_error_handler(
112-
request: Iterator[str] | AsyncIterator[str],
113-
context: grpc.aio.ServicerContext,
114-
) -> AsyncGenerator[str, None]:
115-
raise RiverException("INJECTED_ERROR", "test error")
116-
yield "test" # appease the type checker
117-
118-
11971
@pytest.fixture
12072
def transport_options() -> TransportOptions:
12173
return TransportOptions()
12274

12375

12476
@pytest.fixture
125-
def server(transport_options: TransportOptions) -> Server:
77+
def server_handlers(handlers: HandlerMapping) -> HandlerMapping:
78+
return handlers
79+
80+
81+
@pytest.fixture
82+
def server(
83+
transport_options: TransportOptions, server_handlers: HandlerMapping
84+
) -> Server:
12685
server = Server(server_id="test_server", transport_options=transport_options)
127-
server.add_rpc_handlers(
128-
{
129-
("test_service", "rpc_method"): (
130-
"rpc",
131-
rpc_method_handler(
132-
rpc_handler, deserialize_request, serialize_response
133-
),
134-
),
135-
("test_service", "subscription_method"): (
136-
"subscription",
137-
subscription_method_handler(
138-
subscription_handler, deserialize_request, serialize_response
139-
),
140-
),
141-
("test_service", "upload_method"): (
142-
"upload",
143-
upload_method_handler(
144-
upload_handler, deserialize_request, serialize_response
145-
),
146-
),
147-
("test_service", "stream_method"): (
148-
"stream",
149-
stream_method_handler(
150-
stream_handler, deserialize_request, serialize_response
151-
),
152-
),
153-
("test_service", "stream_method_error"): (
154-
"stream",
155-
stream_method_handler(
156-
stream_error_handler, deserialize_request, serialize_response
157-
),
158-
),
159-
}
160-
)
86+
server.add_rpc_handlers(server_handlers)
16187
return server
16288

16389

tests/test_communication.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,21 @@
66
from replit_river.client import Client
77
from replit_river.error_schema import RiverError
88
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
9-
from tests.conftest import deserialize_error, deserialize_response, serialize_request
9+
from tests.common_handlers import (
10+
basic_rpc_method,
11+
basic_stream,
12+
basic_subscription,
13+
basic_upload,
14+
)
15+
from tests.conftest import (
16+
deserialize_error,
17+
deserialize_response,
18+
serialize_request,
19+
)
1020

1121

1222
@pytest.mark.asyncio
23+
@pytest.mark.parametrize("handlers", [{**basic_rpc_method}])
1324
async def test_rpc_method(client: Client) -> None:
1425
response = await client.send_rpc(
1526
"test_service",
@@ -23,6 +34,7 @@ async def test_rpc_method(client: Client) -> None:
2334

2435

2536
@pytest.mark.asyncio
37+
@pytest.mark.parametrize("handlers", [{**basic_upload}])
2638
async def test_upload_method(client: Client) -> None:
2739
async def upload_data() -> AsyncGenerator[str, None]:
2840
yield "Data 1"
@@ -43,6 +55,7 @@ async def upload_data() -> AsyncGenerator[str, None]:
4355

4456

4557
@pytest.mark.asyncio
58+
@pytest.mark.parametrize("handlers", [{**basic_upload}])
4659
async def test_upload_more_than_send_buffer_max(client: Client) -> None:
4760
iterations = MAX_MESSAGE_BUFFER_SIZE * 2
4861

@@ -64,6 +77,7 @@ async def upload_data() -> AsyncGenerator[str, None]:
6477

6578

6679
@pytest.mark.asyncio
80+
@pytest.mark.parametrize("handlers", [{**basic_upload}])
6781
async def test_upload_empty(client: Client) -> None:
6882
async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
6983
if enabled:
@@ -83,6 +97,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
8397

8498

8599
@pytest.mark.asyncio
100+
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
86101
async def test_subscription_method(client: Client) -> None:
87102
async for response in client.send_subscription(
88103
"test_service",
@@ -97,6 +112,7 @@ async def test_subscription_method(client: Client) -> None:
97112

98113

99114
@pytest.mark.asyncio
115+
@pytest.mark.parametrize("handlers", [{**basic_stream}])
100116
async def test_stream_method(client: Client) -> None:
101117
async def stream_data() -> AsyncGenerator[str, None]:
102118
yield "Stream 1"
@@ -125,6 +141,7 @@ async def stream_data() -> AsyncGenerator[str, None]:
125141

126142

127143
@pytest.mark.asyncio
144+
@pytest.mark.parametrize("handlers", [{**basic_stream}])
128145
async def test_stream_empty(client: Client) -> None:
129146
async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
130147
if enabled:
@@ -147,6 +164,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
147164

148165

149166
@pytest.mark.asyncio
167+
@pytest.mark.parametrize("handlers", [{**basic_upload, **basic_stream}])
150168
async def test_multiplexing(client: Client) -> None:
151169
async def upload_data() -> AsyncGenerator[str, None]:
152170
yield "Upload Data 1"

tests/test_handshake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def transport_options() -> TransportOptions:
1515

1616

1717
@pytest.mark.asyncio
18+
@pytest.mark.parametrize("handlers", [{}])
1819
async def test_handshake_timeout(server: Server) -> None:
1920
async with serve(server.serve, "localhost", 8765):
2021
start = time()

0 commit comments

Comments
 (0)