Skip to content

Commit 88f4347

Browse files
authored
bugfix: include streaming procedure errors in traces (#121)
Why === Streaming response procedures don't raise errors like the other procedures, instead the errors are included in the `AsyncIterator`. Also fixes issues with using contextvars + async generators. What changed ============ - Check if the message from the async iterator is a `RiverError`, if so, record it on the span - Use `start_span` instead of `start_as_current_span`, the latter resets a contextvar in its `finally` clause which is invalid to do for async generators as the async generator's finalizers run in a different context - Thread the span through manually so we still propagate the tracing info Test plan ========= - Should see errors for failed streaming procedures - Logs about resetting contextvars in a different context should go away - Added some tests for the otel stuff to make sure the error handling works here
1 parent c78c653 commit 88f4347

File tree

8 files changed

+247
-32
lines changed

8 files changed

+247
-32
lines changed

replit_river/client.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from typing import Any, Generator, Generic, Literal, Optional, Union
55

66
from opentelemetry import trace
7+
from opentelemetry.trace import Span, SpanKind, StatusCode
78

89
from replit_river.client_transport import ClientTransport
9-
from replit_river.error_schema import RiverException
10+
from replit_river.error_schema import RiverError, RiverException
1011
from replit_river.transport_options import (
1112
HandshakeMetadataType,
1213
TransportOptions,
@@ -60,7 +61,7 @@ async def send_rpc(
6061
response_deserializer: Callable[[Any], ResponseType],
6162
error_deserializer: Callable[[Any], ErrorType],
6263
) -> ResponseType:
63-
with _trace_procedure("rpc", service_name, procedure_name):
64+
with _trace_procedure("rpc", service_name, procedure_name) as span:
6465
session = await self._transport.get_or_create_session()
6566
return await session.send_rpc(
6667
service_name,
@@ -69,6 +70,7 @@ async def send_rpc(
6970
request_serializer,
7071
response_deserializer,
7172
error_deserializer,
73+
span,
7274
)
7375

7476
async def send_upload(
@@ -82,7 +84,7 @@ async def send_upload(
8284
response_deserializer: Callable[[Any], ResponseType],
8385
error_deserializer: Callable[[Any], ErrorType],
8486
) -> ResponseType:
85-
with _trace_procedure("upload", service_name, procedure_name):
87+
with _trace_procedure("upload", service_name, procedure_name) as span:
8688
session = await self._transport.get_or_create_session()
8789
return await session.send_upload(
8890
service_name,
@@ -93,6 +95,7 @@ async def send_upload(
9395
request_serializer,
9496
response_deserializer,
9597
error_deserializer,
98+
span,
9699
)
97100

98101
async def send_subscription(
@@ -104,7 +107,7 @@ async def send_subscription(
104107
response_deserializer: Callable[[Any], ResponseType],
105108
error_deserializer: Callable[[Any], ErrorType],
106109
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
107-
with _trace_procedure("subscription", service_name, procedure_name):
110+
with _trace_procedure("subscription", service_name, procedure_name) as span:
108111
session = await self._transport.get_or_create_session()
109112
async for msg in session.send_subscription(
110113
service_name,
@@ -113,8 +116,11 @@ async def send_subscription(
113116
request_serializer,
114117
response_deserializer,
115118
error_deserializer,
119+
span,
116120
):
117-
yield msg
121+
if isinstance(msg, RiverError):
122+
_record_river_error(span, msg)
123+
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
118124

119125
async def send_stream(
120126
self,
@@ -127,7 +133,7 @@ async def send_stream(
127133
response_deserializer: Callable[[Any], ResponseType],
128134
error_deserializer: Callable[[Any], ErrorType],
129135
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
130-
with _trace_procedure("stream", service_name, procedure_name):
136+
with _trace_procedure("stream", service_name, procedure_name) as span:
131137
session = await self._transport.get_or_create_session()
132138
async for msg in session.send_stream(
133139
service_name,
@@ -138,23 +144,32 @@ async def send_stream(
138144
request_serializer,
139145
response_deserializer,
140146
error_deserializer,
147+
span,
141148
):
142-
yield msg
149+
if isinstance(msg, RiverError):
150+
_record_river_error(span, msg)
151+
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
143152

144153

145154
@contextmanager
146155
def _trace_procedure(
147156
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
148157
service_name: str,
149158
procedure_name: str,
150-
) -> Generator[None, None, None]:
151-
with tracer.start_as_current_span(
159+
) -> Generator[Span, None, None]:
160+
with tracer.start_span(
152161
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
153-
kind=trace.SpanKind.CLIENT,
162+
kind=SpanKind.CLIENT,
154163
) as span:
155164
try:
156-
yield
165+
yield span
157166
except RiverException as e:
158-
span.set_attribute("river.error_code", e.code)
159-
span.set_attribute("river.error_message", e.message)
167+
_record_river_error(span, RiverError(code=e.code, message=e.message))
160168
raise e
169+
170+
171+
def _record_river_error(span: Span, error: RiverError) -> None:
172+
span.set_status(StatusCode.ERROR, error.message)
173+
span.record_exception(RiverException(error.code, error.message))
174+
span.set_attribute("river.error_code", error.code)
175+
span.set_attribute("river.error_message", error.message)

replit_river/client_session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import nanoid # type: ignore
66
from aiochannel import Channel
77
from aiochannel.errors import ChannelClosed
8+
from opentelemetry.trace import Span
89

910
from replit_river.error_schema import (
1011
ERROR_CODE_STREAM_CLOSED,
@@ -37,6 +38,7 @@ async def send_rpc(
3738
request_serializer: Callable[[RequestType], Any],
3839
response_deserializer: Callable[[Any], ResponseType],
3940
error_deserializer: Callable[[Any], ErrorType],
41+
span: Span,
4042
) -> ResponseType:
4143
"""Sends a single RPC request to the server.
4244
@@ -51,6 +53,7 @@ async def send_rpc(
5153
payload=request_serializer(request),
5254
service_name=service_name,
5355
procedure_name=procedure_name,
56+
span=span,
5457
)
5558
# Handle potential errors during communication
5659
try:
@@ -89,6 +92,7 @@ async def send_upload(
8992
request_serializer: Callable[[RequestType], Any],
9093
response_deserializer: Callable[[Any], ResponseType],
9194
error_deserializer: Callable[[Any], ErrorType],
95+
span: Span,
9296
) -> ResponseType:
9397
"""Sends an upload request to the server.
9498
@@ -107,6 +111,7 @@ async def send_upload(
107111
service_name=service_name,
108112
procedure_name=procedure_name,
109113
payload=init_serializer(init),
114+
span=span,
110115
)
111116
first_message = False
112117
# If this request is not closed and the session is killed, we should
@@ -122,6 +127,7 @@ async def send_upload(
122127
procedure_name=procedure_name,
123128
control_flags=control_flags,
124129
payload=request_serializer(item),
130+
span=span,
125131
)
126132
except Exception as e:
127133
raise RiverServiceException(
@@ -171,6 +177,7 @@ async def send_subscription(
171177
request_serializer: Callable[[RequestType], Any],
172178
response_deserializer: Callable[[Any], ResponseType],
173179
error_deserializer: Callable[[Any], ErrorType],
180+
span: Span,
174181
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
175182
"""Sends a subscription request to the server.
176183
@@ -185,6 +192,7 @@ async def send_subscription(
185192
stream_id=stream_id,
186193
control_flags=STREAM_OPEN_BIT,
187194
payload=request_serializer(request),
195+
span=span,
188196
)
189197

190198
# Handle potential errors during communication
@@ -221,6 +229,7 @@ async def send_stream(
221229
request_serializer: Callable[[RequestType], Any],
222230
response_deserializer: Callable[[Any], ResponseType],
223231
error_deserializer: Callable[[Any], ErrorType],
232+
span: Span,
224233
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
225234
"""Sends a subscription request to the server.
226235
@@ -239,6 +248,7 @@ async def send_stream(
239248
stream_id=stream_id,
240249
control_flags=STREAM_OPEN_BIT,
241250
payload=init_serializer(init),
251+
span=span,
242252
)
243253
else:
244254
# Get the very first message to open the stream
@@ -250,6 +260,7 @@ async def send_stream(
250260
stream_id=stream_id,
251261
control_flags=STREAM_OPEN_BIT,
252262
payload=request_serializer(first),
263+
span=span,
253264
)
254265

255266
except StopAsyncIteration:

replit_river/rpc.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,9 @@ async def _convert_outputs() -> None:
388388

389389
convert_inputs_task = task_manager.create_task(_convert_inputs())
390390
convert_outputs_task = task_manager.create_task(_convert_outputs())
391-
await asyncio.wait((convert_inputs_task, convert_outputs_task))
392-
391+
done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
392+
for task in done:
393+
await task
393394
except Exception as e:
394395
logger.exception("Uncaught exception in upload")
395396
await output.put(
@@ -440,17 +441,16 @@ async def _convert_inputs() -> None:
440441
response = method(request, context)
441442

442443
async def _convert_outputs() -> None:
443-
try:
444-
async for item in response:
445-
await output.put(
446-
get_response_or_error_payload(item, response_serializer)
447-
)
448-
finally:
449-
output.close()
444+
async for item in response:
445+
await output.put(
446+
get_response_or_error_payload(item, response_serializer)
447+
)
450448

451449
convert_inputs_task = task_manager.create_task(_convert_inputs())
452450
convert_outputs_task = task_manager.create_task(_convert_outputs())
453-
await asyncio.wait((convert_inputs_task, convert_outputs_task))
451+
done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
452+
for task in done:
453+
await task
454454
except grpc.RpcError:
455455
logger.exception("RPC exception in stream")
456456
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"

replit_river/session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import nanoid # type: ignore
77
import websockets
88
from aiochannel import Channel, ChannelClosed
9+
from opentelemetry.trace import Span, use_span
910
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1011
from websockets.exceptions import ConnectionClosed
1112

@@ -37,6 +38,9 @@
3738

3839
logger = logging.getLogger(__name__)
3940

41+
trace_propagator = TraceContextTextMapPropagator()
42+
trace_setter = TransportMessageTracingSetter()
43+
4044

4145
class SessionState(enum.Enum):
4246
"""The state a session can be in.
@@ -365,6 +369,7 @@ async def send_message(
365369
control_flags: int = 0,
366370
service_name: str | None = None,
367371
procedure_name: str | None = None,
372+
span: Span | None = None,
368373
) -> None:
369374
"""Send serialized messages to the websockets."""
370375
# if the session is not active, we should not do anything
@@ -382,9 +387,9 @@ async def send_message(
382387
serviceName=service_name,
383388
procedureName=procedure_name,
384389
)
385-
TraceContextTextMapPropagator().inject(
386-
msg, None, TransportMessageTracingSetter()
387-
)
390+
if span:
391+
with use_span(span):
392+
trace_propagator.inject(msg, None, trace_setter)
388393
try:
389394
# We need this lock to ensure the buffer order and message sending order
390395
# are the same.

tests/conftest.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from collections.abc import AsyncIterator
44
from typing import Any, AsyncGenerator, Iterator, Literal
55

6+
import grpc.aio
67
import nanoid # type: ignore
78
import pytest
9+
from opentelemetry import trace
10+
from opentelemetry.sdk.trace import TracerProvider
11+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
12+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
813
from websockets.server import serve
914

1015
from replit_river.client import Client
1116
from replit_river.client_transport import UriAndMetadata
12-
from replit_river.error_schema import RiverError
17+
from replit_river.error_schema import RiverError, RiverException
1318
from replit_river.rpc import (
14-
GrpcContext,
1519
TransportMessage,
1620
rpc_method_handler,
1721
stream_method_handler,
@@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError:
6872

6973

7074
# RPC method handlers for testing
71-
async def rpc_handler(request: str, context: GrpcContext) -> str:
75+
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
7276
return f"Hello, {request}!"
7377

7478

7579
async def subscription_handler(
76-
request: str, context: GrpcContext
80+
request: str, context: grpc.aio.ServicerContext
7781
) -> AsyncGenerator[str, None]:
7882
for i in range(5):
7983
yield f"Subscription message {i} for {request}"
@@ -93,7 +97,8 @@ async def upload_handler(
9397

9498

9599
async def stream_handler(
96-
request: Iterator[str] | AsyncIterator[str], context: GrpcContext
100+
request: Iterator[str] | AsyncIterator[str],
101+
context: grpc.aio.ServicerContext,
97102
) -> AsyncGenerator[str, None]:
98103
if isinstance(request, AsyncIterator):
99104
async for data in request:
@@ -103,6 +108,14 @@ async def stream_handler(
103108
yield f"Stream response for {data}"
104109

105110

111+
async def stream_error_handler(
112+
request: Iterator[str] | AsyncIterator[str],
113+
context: grpc.aio.ServicerContext,
114+
) -> AsyncGenerator[str, None]:
115+
raise RiverException("INJECTED_ERROR", "test error")
116+
yield "test" # appease the type checker
117+
118+
106119
@pytest.fixture
107120
def transport_options() -> TransportOptions:
108121
return TransportOptions()
@@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server:
137150
stream_handler, deserialize_request, serialize_response
138151
),
139152
),
153+
("test_service", "stream_method_error"): (
154+
"stream",
155+
stream_method_handler(
156+
stream_error_handler, deserialize_request, serialize_response
157+
),
158+
),
140159
}
141160
)
142161
return server
@@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
173192
await server.close()
174193
# Server should close normally
175194
no_logging_error()
195+
196+
197+
@pytest.fixture(scope="session")
198+
def span_exporter() -> InMemorySpanExporter:
199+
exporter = InMemorySpanExporter()
200+
processor = SimpleSpanProcessor(exporter)
201+
provider = TracerProvider()
202+
provider.add_span_processor(processor)
203+
trace.set_tracer_provider(provider)
204+
return exporter
205+
206+
207+
@pytest.fixture(autouse=True)
208+
def reset_span_exporter(span_exporter: InMemorySpanExporter) -> None:
209+
span_exporter.clear()

tests/river_fixtures/logging.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ class NoErrors:
1515

1616
def __init__(self, caplog: LogCaptureFixture):
1717
self.caplog = caplog
18+
self._allow_errors = False
19+
20+
def allow_errors(self) -> None:
21+
self._allow_errors = True
1822

1923
def __call__(self) -> None:
24+
if self._allow_errors:
25+
return
26+
2027
assert len(self.caplog.get_records("setup")) == 0
2128
assert len(self.caplog.get_records("call")) == 0
2229
assert len(self.caplog.get_records("teardown")) == 0

tests/test_communication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def test_rpc_method(client: Client) -> None:
1818
serialize_request,
1919
deserialize_response,
2020
deserialize_error,
21-
) # type: ignore
21+
)
2222
assert response == "Hello, Alice!"
2323

2424

0 commit comments

Comments
 (0)