Skip to content

Commit 855d03f

Browse files
committed
Don't record GeneratorExit errors in stream RPC spans
1 parent ead7ce3 commit 855d03f

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

replit_river/client.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
2-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
2+
from asyncio import CancelledError
3+
from collections.abc import AsyncIterable, Awaitable, Callable
34
from contextlib import contextmanager
45
from datetime import timedelta
5-
from typing import Any, Generator, Generic, Literal, Optional, Union
6+
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union
67

78
from opentelemetry import trace
89
from opentelemetry.trace import Span, SpanKind, StatusCode
@@ -109,7 +110,7 @@ async def send_subscription(
109110
request_serializer: Callable[[RequestType], Any],
110111
response_deserializer: Callable[[Any], ResponseType],
111112
error_deserializer: Callable[[Any], ErrorType],
112-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
113+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
113114
with _trace_procedure("subscription", service_name, procedure_name) as span:
114115
session = await self._transport.get_or_create_session()
115116
async for msg in session.send_subscription(
@@ -135,7 +136,7 @@ async def send_stream(
135136
request_serializer: Callable[[RequestType], Any],
136137
response_deserializer: Callable[[Any], ResponseType],
137138
error_deserializer: Callable[[Any], ErrorType],
138-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
139+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
139140
with _trace_procedure("stream", service_name, procedure_name) as span:
140141
session = await self._transport.get_or_create_session()
141142
async for msg in session.send_stream(
@@ -160,15 +161,27 @@ def _trace_procedure(
160161
service_name: str,
161162
procedure_name: str,
162163
) -> Generator[Span, None, None]:
163-
with tracer.start_span(
164+
span = tracer.start_span(
164165
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
165166
kind=SpanKind.CLIENT,
166-
) as span:
167-
try:
168-
yield span
169-
except RiverException as e:
170-
_record_river_error(span, RiverError(code=e.code, message=e.message))
171-
raise e
167+
)
168+
set_status = False
169+
try:
170+
yield span
171+
except RiverException as e:
172+
span.record_exception(e)
173+
_record_river_error(span, RiverError(code=e.code, message=e.message))
174+
set_status = True
175+
raise e
176+
except (Exception, CancelledError) as e:
177+
span.record_exception(e)
178+
span.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
179+
set_status = True
180+
raise e
181+
finally:
182+
if not set_status:
183+
span.set_status(StatusCode.OK)
184+
span.end()
172185

173186

174187
def _record_river_error(span: Span, error: RiverError) -> None:

replit_river/client_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
22
import logging
3-
from collections.abc import AsyncIterable, AsyncIterator
3+
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, AsyncGenerator, Callable, Optional, Union
66

77
import nanoid # type: ignore
88
from aiochannel import Channel
@@ -194,7 +194,7 @@ async def send_subscription(
194194
response_deserializer: Callable[[Any], ResponseType],
195195
error_deserializer: Callable[[Any], ErrorType],
196196
span: Span,
197-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
197+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
198198
"""Sends a subscription request to the server.
199199
200200
Expects the input and output be messages that will be msgpacked.
@@ -246,7 +246,7 @@ async def send_stream(
246246
response_deserializer: Callable[[Any], ResponseType],
247247
error_deserializer: Callable[[Any], ErrorType],
248248
span: Span,
249-
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
249+
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
250250
"""Sends a subscription request to the server.
251251
252252
Expects the input and output be messages that will be msgpacked.

tests/test_opentelemetry.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
from datetime import timedelta
23
from typing import AsyncGenerator, AsyncIterator, Iterator
34

@@ -182,3 +183,39 @@ async def stream_data() -> AsyncGenerator[str, None]:
182183
assert len(spans) == 1
183184
assert spans[0].name == "river.client.stream.test_service.stream_method_error"
184185
assert spans[0].status.status_code == StatusCode.ERROR
186+
187+
188+
@pytest.mark.asyncio
189+
@pytest.mark.parametrize("handlers", [{**basic_stream}])
190+
async def test_stream_method_span_generator_exit_not_recorded(
191+
client: Client, span_exporter: InMemorySpanExporter
192+
) -> None:
193+
async def stream_data() -> AsyncGenerator[str, None]:
194+
yield "Stream 1"
195+
yield "Stream 2"
196+
yield "Stream 3"
197+
198+
responses = []
199+
stream = client.send_stream(
200+
"test_service",
201+
"stream_method",
202+
"Initial Stream Data",
203+
stream_data(),
204+
serialize_request,
205+
serialize_request,
206+
deserialize_response,
207+
deserialize_error,
208+
)
209+
async with contextlib.aclosing(stream) as generator:
210+
async for response in generator:
211+
responses.append(response)
212+
break
213+
214+
assert responses == [
215+
"Stream response for Initial Stream Data",
216+
]
217+
218+
spans = span_exporter.get_finished_spans()
219+
assert len(spans) == 1
220+
assert spans[0].name == "river.client.stream.test_service.stream_method"
221+
assert spans[0].status.status_code == StatusCode.OK

0 commit comments

Comments
 (0)