Skip to content

Commit a2dfcb9

Browse files
committed
Fix contextvar issues with async generators
1 parent d659175 commit a2dfcb9

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

replit_river/client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def send_rpc(
6060
response_deserializer: Callable[[Any], ResponseType],
6161
error_deserializer: Callable[[Any], ErrorType],
6262
) -> ResponseType:
63-
with _trace_procedure("rpc", service_name, procedure_name):
63+
with _trace_procedure("rpc", service_name, procedure_name) as span:
6464
session = await self._transport.get_or_create_session()
6565
return await session.send_rpc(
6666
service_name,
@@ -69,6 +69,7 @@ async def send_rpc(
6969
request_serializer,
7070
response_deserializer,
7171
error_deserializer,
72+
span,
7273
)
7374

7475
async def send_upload(
@@ -82,7 +83,7 @@ async def send_upload(
8283
response_deserializer: Callable[[Any], ResponseType],
8384
error_deserializer: Callable[[Any], ErrorType],
8485
) -> ResponseType:
85-
with _trace_procedure("upload", service_name, procedure_name):
86+
with _trace_procedure("upload", service_name, procedure_name) as span:
8687
session = await self._transport.get_or_create_session()
8788
return await session.send_upload(
8889
service_name,
@@ -93,6 +94,7 @@ async def send_upload(
9394
request_serializer,
9495
response_deserializer,
9596
error_deserializer,
97+
span,
9698
)
9799

98100
async def send_subscription(
@@ -104,7 +106,7 @@ async def send_subscription(
104106
response_deserializer: Callable[[Any], ResponseType],
105107
error_deserializer: Callable[[Any], ErrorType],
106108
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
107-
with _trace_procedure("subscription", service_name, procedure_name):
109+
with _trace_procedure("subscription", service_name, procedure_name) as span:
108110
session = await self._transport.get_or_create_session()
109111
async for msg in session.send_subscription(
110112
service_name,
@@ -113,6 +115,7 @@ async def send_subscription(
113115
request_serializer,
114116
response_deserializer,
115117
error_deserializer,
118+
span,
116119
):
117120
if isinstance(msg, RiverError):
118121
_record_river_error(msg)
@@ -129,7 +132,7 @@ async def send_stream(
129132
response_deserializer: Callable[[Any], ResponseType],
130133
error_deserializer: Callable[[Any], ErrorType],
131134
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
132-
with _trace_procedure("stream", service_name, procedure_name):
135+
with _trace_procedure("stream", service_name, procedure_name) as span:
133136
session = await self._transport.get_or_create_session()
134137
async for msg in session.send_stream(
135138
service_name,
@@ -140,6 +143,7 @@ async def send_stream(
140143
request_serializer,
141144
response_deserializer,
142145
error_deserializer,
146+
span,
143147
):
144148
if isinstance(msg, RiverError):
145149
_record_river_error(msg)
@@ -158,13 +162,13 @@ def _trace_procedure(
158162
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
159163
service_name: str,
160164
procedure_name: str,
161-
) -> Generator[None, None, None]:
162-
with tracer.start_as_current_span(
165+
) -> Generator[trace.Span, None, None]:
166+
with tracer.start_span(
163167
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
164168
kind=trace.SpanKind.CLIENT,
165169
) as span:
166170
try:
167-
yield
171+
yield span
168172
except RiverException as e:
169173
span.set_attribute("river.error_code", e.code)
170174
span.set_attribute("river.error_message", e.message)

replit_river/client_session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import nanoid # type: ignore
66
from aiochannel import Channel
77
from aiochannel.errors import ChannelClosed
8+
from opentelemetry.trace import Span
89

910
from replit_river.error_schema import (
1011
ERROR_CODE_STREAM_CLOSED,
@@ -37,6 +38,7 @@ async def send_rpc(
3738
request_serializer: Callable[[RequestType], Any],
3839
response_deserializer: Callable[[Any], ResponseType],
3940
error_deserializer: Callable[[Any], ErrorType],
41+
span: Span,
4042
) -> ResponseType:
4143
"""Sends a single RPC request to the server.
4244
@@ -51,6 +53,7 @@ async def send_rpc(
5153
payload=request_serializer(request),
5254
service_name=service_name,
5355
procedure_name=procedure_name,
56+
span=span,
5457
)
5558
# Handle potential errors during communication
5659
try:
@@ -89,6 +92,7 @@ async def send_upload(
8992
request_serializer: Callable[[RequestType], Any],
9093
response_deserializer: Callable[[Any], ResponseType],
9194
error_deserializer: Callable[[Any], ErrorType],
95+
span: Span,
9296
) -> ResponseType:
9397
"""Sends an upload request to the server.
9498
@@ -107,6 +111,7 @@ async def send_upload(
107111
service_name=service_name,
108112
procedure_name=procedure_name,
109113
payload=init_serializer(init),
114+
span=span,
110115
)
111116
first_message = False
112117
# If this request is not closed and the session is killed, we should
@@ -122,6 +127,7 @@ async def send_upload(
122127
procedure_name=procedure_name,
123128
control_flags=control_flags,
124129
payload=request_serializer(item),
130+
span=span,
125131
)
126132
except Exception as e:
127133
raise RiverServiceException(
@@ -171,6 +177,7 @@ async def send_subscription(
171177
request_serializer: Callable[[RequestType], Any],
172178
response_deserializer: Callable[[Any], ResponseType],
173179
error_deserializer: Callable[[Any], ErrorType],
180+
span: Span,
174181
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
175182
"""Sends a subscription request to the server.
176183
@@ -185,6 +192,7 @@ async def send_subscription(
185192
stream_id=stream_id,
186193
control_flags=STREAM_OPEN_BIT,
187194
payload=request_serializer(request),
195+
span=span,
188196
)
189197

190198
# Handle potential errors during communication
@@ -221,6 +229,7 @@ async def send_stream(
221229
request_serializer: Callable[[RequestType], Any],
222230
response_deserializer: Callable[[Any], ResponseType],
223231
error_deserializer: Callable[[Any], ErrorType],
232+
span: Span,
224233
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
225234
"""Sends a subscription request to the server.
226235
@@ -239,6 +248,7 @@ async def send_stream(
239248
stream_id=stream_id,
240249
control_flags=STREAM_OPEN_BIT,
241250
payload=init_serializer(init),
251+
span=span,
242252
)
243253
else:
244254
# Get the very first message to open the stream
@@ -250,6 +260,7 @@ async def send_stream(
250260
stream_id=stream_id,
251261
control_flags=STREAM_OPEN_BIT,
252262
payload=request_serializer(first),
263+
span=span,
253264
)
254265

255266
except StopAsyncIteration:

replit_river/session.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import nanoid # type: ignore
77
import websockets
88
from aiochannel import Channel, ChannelClosed
9+
from opentelemetry.trace import Span, use_span
910
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1011
from websockets.exceptions import ConnectionClosed
1112

@@ -365,6 +366,7 @@ async def send_message(
365366
control_flags: int = 0,
366367
service_name: str | None = None,
367368
procedure_name: str | None = None,
369+
span: Span | None = None,
368370
) -> None:
369371
"""Send serialized messages to the websockets."""
370372
# if the session is not active, we should not do anything
@@ -382,9 +384,11 @@ async def send_message(
382384
serviceName=service_name,
383385
procedureName=procedure_name,
384386
)
385-
TraceContextTextMapPropagator().inject(
386-
msg, None, TransportMessageTracingSetter()
387-
)
387+
if span:
388+
with use_span(span):
389+
TraceContextTextMapPropagator().inject(
390+
msg, None, TransportMessageTracingSetter()
391+
)
388392
try:
389393
# We need this lock to ensure the buffer order and message sending order
390394
# are the same.

0 commit comments

Comments
 (0)