Skip to content

Commit 7717461

Browse files
committed
Add opentelemetry instrumentation
1 parent fbe67b7 commit 7717461

File tree

6 files changed

+206
-42
lines changed

6 files changed

+206
-42
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ dependencies = [
3131
"protobuf>=5.28.3",
3232
"pydantic-core>=2.20.1",
3333
"websockets>=12.0",
34+
"opentelemetry-sdk>=1.28.2",
35+
"opentelemetry-api>=1.28.2",
3436
]
3537

3638
[tool.uv]

replit_river/client.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
3-
from typing import Any, Generic, Optional, Union
3+
from contextlib import contextmanager
4+
from typing import Any, Generator, Generic, Literal, Optional, Union
5+
6+
from opentelemetry import trace
47

58
from replit_river.client_transport import ClientTransport
9+
from replit_river.error_schema import RiverException
610
from replit_river.transport_options import (
711
HandshakeMetadataType,
812
TransportOptions,
@@ -17,6 +21,7 @@
1721
)
1822

1923
logger = logging.getLogger(__name__)
24+
tracer = trace.get_tracer(__name__)
2025

2126

2227
class Client(Generic[HandshakeMetadataType]):
@@ -55,15 +60,16 @@ async def send_rpc(
5560
response_deserializer: Callable[[Any], ResponseType],
5661
error_deserializer: Callable[[Any], ErrorType],
5762
) -> ResponseType:
58-
session = await self._transport.get_or_create_session()
59-
return await session.send_rpc(
60-
service_name,
61-
procedure_name,
62-
request,
63-
request_serializer,
64-
response_deserializer,
65-
error_deserializer,
66-
)
63+
with _trace_procedure("rpc", service_name, procedure_name):
64+
session = await self._transport.get_or_create_session()
65+
return await session.send_rpc(
66+
service_name,
67+
procedure_name,
68+
request,
69+
request_serializer,
70+
response_deserializer,
71+
error_deserializer,
72+
)
6773

6874
async def send_upload(
6975
self,
@@ -76,17 +82,18 @@ async def send_upload(
7682
response_deserializer: Callable[[Any], ResponseType],
7783
error_deserializer: Callable[[Any], ErrorType],
7884
) -> ResponseType:
79-
session = await self._transport.get_or_create_session()
80-
return await session.send_upload(
81-
service_name,
82-
procedure_name,
83-
init,
84-
request,
85-
init_serializer,
86-
request_serializer,
87-
response_deserializer,
88-
error_deserializer,
89-
)
85+
with _trace_procedure("upload", service_name, procedure_name):
86+
session = await self._transport.get_or_create_session()
87+
return await session.send_upload(
88+
service_name,
89+
procedure_name,
90+
init,
91+
request,
92+
init_serializer,
93+
request_serializer,
94+
response_deserializer,
95+
error_deserializer,
96+
)
9097

9198
async def send_subscription(
9299
self,
@@ -97,15 +104,16 @@ async def send_subscription(
97104
response_deserializer: Callable[[Any], ResponseType],
98105
error_deserializer: Callable[[Any], ErrorType],
99106
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
100-
session = await self._transport.get_or_create_session()
101-
return session.send_subscription(
102-
service_name,
103-
procedure_name,
104-
request,
105-
request_serializer,
106-
response_deserializer,
107-
error_deserializer,
108-
)
107+
with _trace_procedure("subscription", service_name, procedure_name):
108+
session = await self._transport.get_or_create_session()
109+
return session.send_subscription(
110+
service_name,
111+
procedure_name,
112+
request,
113+
request_serializer,
114+
response_deserializer,
115+
error_deserializer,
116+
)
109117

110118
async def send_stream(
111119
self,
@@ -118,14 +126,33 @@ async def send_stream(
118126
response_deserializer: Callable[[Any], ResponseType],
119127
error_deserializer: Callable[[Any], ErrorType],
120128
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
121-
session = await self._transport.get_or_create_session()
122-
return session.send_stream(
123-
service_name,
124-
procedure_name,
125-
init,
126-
request,
127-
init_serializer,
128-
request_serializer,
129-
response_deserializer,
130-
error_deserializer,
131-
)
129+
with _trace_procedure("stream", service_name, procedure_name):
130+
session = await self._transport.get_or_create_session()
131+
return session.send_stream(
132+
service_name,
133+
procedure_name,
134+
init,
135+
request,
136+
init_serializer,
137+
request_serializer,
138+
response_deserializer,
139+
error_deserializer,
140+
)
141+
142+
143+
@contextmanager
144+
def _trace_procedure(
145+
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
146+
service_name: str,
147+
procedure_name: str,
148+
) -> Generator[None, None, None]:
149+
with tracer.start_as_current_span(
150+
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
151+
kind=trace.SpanKind.CLIENT,
152+
) as span:
153+
try:
154+
yield
155+
except RiverException as e:
156+
span.set_attribute("river.error_code", e.code)
157+
span.set_attribute("river.error_message", e.message)
158+
raise e

replit_river/rpc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import grpc
2424
from aiochannel import Channel, ChannelClosed
25+
from opentelemetry.propagators.textmap import Setter
2526
from pydantic import BaseModel, ConfigDict, Field
2627

2728
from replit_river.error_schema import (
@@ -86,6 +87,11 @@ class ControlMessageHandshakeResponse(BaseModel):
8687
status: HandShakeStatus
8788

8889

90+
class PropagationContext(BaseModel):
91+
traceparent: str = Field(default="")
92+
tracestate: str = Field(default="")
93+
94+
8995
class TransportMessage(BaseModel):
9096
id: str
9197
# from_ is used instead of from because from is a reserved keyword in Python
@@ -97,12 +103,30 @@ class TransportMessage(BaseModel):
97103
procedureName: Optional[str] = None
98104
streamId: str
99105
controlFlags: int
106+
tracing: Optional[PropagationContext] = None
100107
payload: Any
101108
model_config = ConfigDict(populate_by_name=True)
102109
# need this because we create TransportMessage objects with destructuring
103110
# where the key is "from"
104111

105112

113+
class TransportMessageTracingSetter(Setter[TransportMessage]):
114+
"""
115+
Handles propagating tracing context to the recipient of the message.
116+
"""
117+
118+
def set(self, carrier: TransportMessage, key: str, value: str) -> None:
119+
if not carrier.tracing:
120+
carrier.tracing = PropagationContext()
121+
match key:
122+
case "traceparent":
123+
carrier.tracing.traceparent = value
124+
case "tracestate":
125+
carrier.tracing.tracestate = value
126+
case _:
127+
logger.warning("unknown trace propagation key", extra={"key": key})
128+
129+
106130
class GrpcContext(grpc.aio.ServicerContext):
107131
"""Represents a gRPC-compatible ServicerContext for River interop."""
108132

replit_river/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import nanoid # type: ignore
77
import websockets
88
from aiochannel import Channel, ChannelClosed
9+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
910
from websockets.exceptions import ConnectionClosed
1011

1112
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
@@ -31,6 +32,7 @@
3132
STREAM_OPEN_BIT,
3233
GenericRpcHandler,
3334
TransportMessage,
35+
TransportMessageTracingSetter,
3436
)
3537

3638
logger = logging.getLogger(__name__)
@@ -380,6 +382,9 @@ async def send_message(
380382
serviceName=service_name,
381383
procedureName=procedure_name,
382384
)
385+
TraceContextTextMapPropagator().inject(
386+
msg, None, TransportMessageTracingSetter()
387+
)
383388
try:
384389
# We need this lock to ensure the buffer order and message sending order
385390
# are the same.

scripts/lint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/bin/bash
1+
#!/usr/bin/env bash
22

33
set -ex
44

0 commit comments

Comments
 (0)