Skip to content

Commit 320fb9b

Browse files
committed
Add some test
1 parent a2dfcb9 commit 320fb9b

File tree

6 files changed

+210
-27
lines changed

6 files changed

+210
-27
lines changed

replit_river/client.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Generator, Generic, Literal, Optional, Union
55

66
from opentelemetry import trace
7+
from opentelemetry.trace import Span, SpanKind, StatusCode
78

89
from replit_river.client_transport import ClientTransport
910
from replit_river.error_schema import RiverError, RiverException
@@ -118,7 +119,7 @@ async def send_subscription(
118119
span,
119120
):
120121
if isinstance(msg, RiverError):
121-
_record_river_error(msg)
122+
_record_river_error(span, msg)
122123
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
123124

124125
async def send_stream(
@@ -146,30 +147,29 @@ async def send_stream(
146147
span,
147148
):
148149
if isinstance(msg, RiverError):
149-
_record_river_error(msg)
150+
_record_river_error(span, msg)
150151
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
151152

152153

153-
def _record_river_error(error: RiverError) -> None:
154-
span = trace.get_current_span()
155-
span.record_exception(RiverException(error.code, error.message))
156-
span.set_attribute("river.error_code", error.code)
157-
span.set_attribute("river.error_message", error.message)
158-
159-
160154
@contextmanager
161155
def _trace_procedure(
162156
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
163157
service_name: str,
164158
procedure_name: str,
165-
) -> Generator[trace.Span, None, None]:
159+
) -> Generator[Span, None, None]:
166160
with tracer.start_span(
167161
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
168-
kind=trace.SpanKind.CLIENT,
162+
kind=SpanKind.CLIENT,
169163
) as span:
170164
try:
171165
yield span
172166
except RiverException as e:
173-
span.set_attribute("river.error_code", e.code)
174-
span.set_attribute("river.error_message", e.message)
167+
_record_river_error(span, RiverError(code=e.code, message=e.message))
175168
raise e
169+
170+
171+
def _record_river_error(span: Span, error: RiverError) -> None:
172+
span.set_status(StatusCode.ERROR, error.message)
173+
span.record_exception(RiverException(error.code, error.message))
174+
span.set_attribute("river.error_code", error.code)
175+
span.set_attribute("river.error_message", error.message)

replit_river/rpc.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -440,17 +440,16 @@ async def _convert_inputs() -> None:
440440
response = method(request, context)
441441

442442
async def _convert_outputs() -> None:
443-
try:
444-
async for item in response:
445-
await output.put(
446-
get_response_or_error_payload(item, response_serializer)
447-
)
448-
finally:
449-
output.close()
443+
async for item in response:
444+
await output.put(
445+
get_response_or_error_payload(item, response_serializer)
446+
)
450447

451448
convert_inputs_task = task_manager.create_task(_convert_inputs())
452449
convert_outputs_task = task_manager.create_task(_convert_outputs())
453-
await asyncio.wait((convert_inputs_task, convert_outputs_task))
450+
done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
451+
for task in done:
452+
await task
454453
except grpc.RpcError:
455454
logger.exception("RPC exception in stream")
456455
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"

tests/conftest.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from collections.abc import AsyncIterator
44
from typing import Any, AsyncGenerator, Iterator, Literal
55

6+
import grpc.aio
67
import nanoid # type: ignore
78
import pytest
9+
from opentelemetry import trace
10+
from opentelemetry.sdk.trace import TracerProvider
11+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
12+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
813
from websockets.server import serve
914

1015
from replit_river.client import Client
1116
from replit_river.client_transport import UriAndMetadata
12-
from replit_river.error_schema import RiverError
17+
from replit_river.error_schema import RiverError, RiverException
1318
from replit_river.rpc import (
14-
GrpcContext,
1519
TransportMessage,
1620
rpc_method_handler,
1721
stream_method_handler,
@@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError:
6872

6973

7074
# RPC method handlers for testing
71-
async def rpc_handler(request: str, context: GrpcContext) -> str:
75+
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
7276
return f"Hello, {request}!"
7377

7478

7579
async def subscription_handler(
76-
request: str, context: GrpcContext
80+
request: str, context: grpc.aio.ServicerContext
7781
) -> AsyncGenerator[str, None]:
7882
for i in range(5):
7983
yield f"Subscription message {i} for {request}"
@@ -93,7 +97,8 @@ async def upload_handler(
9397

9498

9599
async def stream_handler(
96-
request: Iterator[str] | AsyncIterator[str], context: GrpcContext
100+
request: Iterator[str] | AsyncIterator[str],
101+
context: grpc.aio.ServicerContext,
97102
) -> AsyncGenerator[str, None]:
98103
if isinstance(request, AsyncIterator):
99104
async for data in request:
@@ -103,6 +108,14 @@ async def stream_handler(
103108
yield f"Stream response for {data}"
104109

105110

111+
async def stream_error_handler(
112+
request: Iterator[str] | AsyncIterator[str],
113+
context: grpc.aio.ServicerContext,
114+
) -> AsyncGenerator[str, None]:
115+
raise RiverException("INJECTED_ERROR", "test error")
116+
yield "test" # appease the type checker
117+
118+
106119
@pytest.fixture
107120
def transport_options() -> TransportOptions:
108121
return TransportOptions()
@@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server:
137150
stream_handler, deserialize_request, serialize_response
138151
),
139152
),
153+
("test_service", "stream_method_error"): (
154+
"stream",
155+
stream_method_handler(
156+
stream_error_handler, deserialize_request, serialize_response
157+
),
158+
),
140159
}
141160
)
142161
return server
@@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
173192
await server.close()
174193
# Server should close normally
175194
no_logging_error()
195+
196+
197+
@pytest.fixture(scope="session")
198+
def span_exporter() -> InMemorySpanExporter:
199+
exporter = InMemorySpanExporter()
200+
processor = SimpleSpanProcessor(exporter)
201+
provider = TracerProvider()
202+
provider.add_span_processor(processor)
203+
trace.set_tracer_provider(provider)
204+
return exporter
205+
206+
207+
@pytest.fixture(autouse=True)
208+
def reset_span_exporter(span_exporter: InMemorySpanExporter) -> None:
209+
span_exporter.clear()

tests/river_fixtures/logging.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ class NoErrors:
1515

1616
def __init__(self, caplog: LogCaptureFixture):
1717
self.caplog = caplog
18+
self._allow_errors = False
19+
20+
def allow_errors(self) -> None:
21+
self._allow_errors = True
1822

1923
def __call__(self) -> None:
24+
if self._allow_errors:
25+
return
26+
2027
assert len(self.caplog.get_records("setup")) == 0
2128
assert len(self.caplog.get_records("call")) == 0
2229
assert len(self.caplog.get_records("teardown")) == 0

tests/test_communication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def test_rpc_method(client: Client) -> None:
1818
serialize_request,
1919
deserialize_response,
2020
deserialize_error,
21-
) # type: ignore
21+
)
2222
assert response == "Hello, Alice!"
2323

2424

tests/test_opentelemetry.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from typing import AsyncGenerator
2+
3+
import pytest
4+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
5+
from opentelemetry.trace import StatusCode
6+
7+
from replit_river.client import Client
8+
from replit_river.error_schema import RiverError
9+
from tests.conftest import deserialize_error, deserialize_response, serialize_request
10+
from tests.river_fixtures.logging import NoErrors
11+
12+
13+
@pytest.mark.asyncio
14+
async def test_rpc_method_span(
15+
client: Client, span_exporter: InMemorySpanExporter
16+
) -> None:
17+
response = await client.send_rpc(
18+
"test_service",
19+
"rpc_method",
20+
"Alice",
21+
serialize_request,
22+
deserialize_response,
23+
deserialize_error,
24+
)
25+
assert response == "Hello, Alice!"
26+
spans = span_exporter.get_finished_spans()
27+
assert len(spans) == 1
28+
assert spans[0].name == "river.client.rpc.test_service.rpc_method"
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_upload_method_span(
33+
client: Client, span_exporter: InMemorySpanExporter
34+
) -> None:
35+
async def upload_data() -> AsyncGenerator[str, None]:
36+
yield "Data 1"
37+
yield "Data 2"
38+
yield "Data 3"
39+
40+
response = await client.send_upload(
41+
"test_service",
42+
"upload_method",
43+
"Initial Data",
44+
upload_data(),
45+
serialize_request,
46+
serialize_request,
47+
deserialize_response,
48+
deserialize_error,
49+
)
50+
assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3"
51+
spans = span_exporter.get_finished_spans()
52+
assert len(spans) == 1
53+
assert spans[0].name == "river.client.upload.test_service.upload_method"
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_subscription_method_span(
58+
client: Client, span_exporter: InMemorySpanExporter
59+
) -> None:
60+
async for response in client.send_subscription(
61+
"test_service",
62+
"subscription_method",
63+
"Bob",
64+
serialize_request,
65+
deserialize_response,
66+
deserialize_error,
67+
):
68+
assert isinstance(response, str)
69+
assert "Subscription message" in response
70+
71+
spans = span_exporter.get_finished_spans()
72+
assert len(spans) == 1
73+
assert spans[0].name == "river.client.subscription.test_service.subscription_method"
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_stream_method_span(
78+
client: Client, span_exporter: InMemorySpanExporter
79+
) -> None:
80+
async def stream_data() -> AsyncGenerator[str, None]:
81+
yield "Stream 1"
82+
yield "Stream 2"
83+
yield "Stream 3"
84+
85+
responses = []
86+
async for response in client.send_stream(
87+
"test_service",
88+
"stream_method",
89+
"Initial Stream Data",
90+
stream_data(),
91+
serialize_request,
92+
serialize_request,
93+
deserialize_response,
94+
deserialize_error,
95+
):
96+
responses.append(response)
97+
98+
assert responses == [
99+
"Stream response for Initial Stream Data",
100+
"Stream response for Stream 1",
101+
"Stream response for Stream 2",
102+
"Stream response for Stream 3",
103+
]
104+
105+
spans = span_exporter.get_finished_spans()
106+
assert len(spans) == 1
107+
assert spans[0].name == "river.client.stream.test_service.stream_method"
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_stream_error_method_span(
112+
client: Client,
113+
span_exporter: InMemorySpanExporter,
114+
no_logging_error: NoErrors,
115+
) -> None:
116+
# We are explicitly testing errors.
117+
no_logging_error.allow_errors()
118+
119+
async def stream_data() -> AsyncGenerator[str, None]:
120+
yield "Stream 1"
121+
yield "Stream 2"
122+
yield "Stream 3"
123+
124+
responses = []
125+
async for response in client.send_stream(
126+
"test_service",
127+
"stream_method_error",
128+
"Initial Stream Data",
129+
stream_data(),
130+
serialize_request,
131+
serialize_request,
132+
deserialize_response,
133+
deserialize_error,
134+
):
135+
responses.append(response)
136+
137+
assert len(responses) == 1
138+
assert isinstance(responses[0], RiverError)
139+
140+
spans = span_exporter.get_finished_spans()
141+
assert len(spans) == 1
142+
assert spans[0].name == "river.client.stream.test_service.stream_method_error"
143+
assert spans[0].status.status_code == StatusCode.ERROR

0 commit comments

Comments
 (0)