Skip to content

Commit 177916d

Browse files
committed
Add opentelemetry instrumentation
1 parent 8b7034d commit 177916d

File tree

6 files changed

+216
-52
lines changed

6 files changed

+216
-52
lines changed

pyproject.toml

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ classifiers = [
2121
]
2222
dependencies = [
2323
"pydantic==2.9.2",
24-
"aiochannel>=1.2.1",
25-
"black>=23.11,<25.0",
26-
"grpcio-tools>=1.59.3",
27-
"grpcio>=1.59.3",
28-
"msgpack-types>=0.3.0",
29-
"msgpack>=1.0.7",
30-
"nanoid>=2.0.0",
31-
"protobuf>=4.24.4",
32-
"pydantic-core>=2.20.1",
33-
"websockets>=12.0",
24+
"aiochannel>=1.2.1",
25+
"black>=23.11,<25.0",
26+
"grpcio-tools>=1.59.3",
27+
"grpcio>=1.59.3",
28+
"msgpack-types>=0.3.0",
29+
"msgpack>=1.0.7",
30+
"nanoid>=2.0.0",
31+
"protobuf>=4.24.4",
32+
"pydantic-core>=2.20.1",
33+
"websockets>=12.0",
34+
"opentelemetry-sdk>=1.28.2",
35+
"opentelemetry-api>=1.28.2",
3436
]
3537

3638
[tool.uv]

replit_river/client.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
3-
from typing import Any, Generic, Optional, Union
3+
from contextlib import contextmanager
4+
from typing import Any, Generator, Generic, Literal, Optional, Union
5+
6+
from opentelemetry import trace
47

58
from replit_river.client_transport import ClientTransport
9+
from replit_river.error_schema import RiverException
610
from replit_river.transport_options import (
711
HandshakeMetadataType,
812
TransportOptions,
@@ -17,6 +21,7 @@
1721
)
1822

1923
logger = logging.getLogger(__name__)
24+
tracer = trace.get_tracer(__name__)
2025

2126

2227
class Client(Generic[HandshakeMetadataType]):
@@ -55,15 +60,16 @@ async def send_rpc(
5560
response_deserializer: Callable[[Any], ResponseType],
5661
error_deserializer: Callable[[Any], ErrorType],
5762
) -> ResponseType:
58-
session = await self._transport.get_or_create_session()
59-
return await session.send_rpc(
60-
service_name,
61-
procedure_name,
62-
request,
63-
request_serializer,
64-
response_deserializer,
65-
error_deserializer,
66-
)
63+
with _trace_procedure("rpc", service_name, procedure_name):
64+
session = await self._transport.get_or_create_session()
65+
return await session.send_rpc(
66+
service_name,
67+
procedure_name,
68+
request,
69+
request_serializer,
70+
response_deserializer,
71+
error_deserializer,
72+
)
6773

6874
async def send_upload(
6975
self,
@@ -76,17 +82,18 @@ async def send_upload(
7682
response_deserializer: Callable[[Any], ResponseType],
7783
error_deserializer: Callable[[Any], ErrorType],
7884
) -> ResponseType:
79-
session = await self._transport.get_or_create_session()
80-
return await session.send_upload(
81-
service_name,
82-
procedure_name,
83-
init,
84-
request,
85-
init_serializer,
86-
request_serializer,
87-
response_deserializer,
88-
error_deserializer,
89-
)
85+
with _trace_procedure("upload", service_name, procedure_name):
86+
session = await self._transport.get_or_create_session()
87+
return await session.send_upload(
88+
service_name,
89+
procedure_name,
90+
init,
91+
request,
92+
init_serializer,
93+
request_serializer,
94+
response_deserializer,
95+
error_deserializer,
96+
)
9097

9198
async def send_subscription(
9299
self,
@@ -97,15 +104,16 @@ async def send_subscription(
97104
response_deserializer: Callable[[Any], ResponseType],
98105
error_deserializer: Callable[[Any], ErrorType],
99106
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
100-
session = await self._transport.get_or_create_session()
101-
return session.send_subscription(
102-
service_name,
103-
procedure_name,
104-
request,
105-
request_serializer,
106-
response_deserializer,
107-
error_deserializer,
108-
)
107+
with _trace_procedure("subscription", service_name, procedure_name):
108+
session = await self._transport.get_or_create_session()
109+
return session.send_subscription(
110+
service_name,
111+
procedure_name,
112+
request,
113+
request_serializer,
114+
response_deserializer,
115+
error_deserializer,
116+
)
109117

110118
async def send_stream(
111119
self,
@@ -118,14 +126,33 @@ async def send_stream(
118126
response_deserializer: Callable[[Any], ResponseType],
119127
error_deserializer: Callable[[Any], ErrorType],
120128
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
121-
session = await self._transport.get_or_create_session()
122-
return session.send_stream(
123-
service_name,
124-
procedure_name,
125-
init,
126-
request,
127-
init_serializer,
128-
request_serializer,
129-
response_deserializer,
130-
error_deserializer,
131-
)
129+
with _trace_procedure("stream", service_name, procedure_name):
130+
session = await self._transport.get_or_create_session()
131+
return session.send_stream(
132+
service_name,
133+
procedure_name,
134+
init,
135+
request,
136+
init_serializer,
137+
request_serializer,
138+
response_deserializer,
139+
error_deserializer,
140+
)
141+
142+
143+
@contextmanager
def _trace_procedure(
    procedure_type: Literal["rpc", "upload", "subscription", "stream"],
    service_name: str,
    procedure_name: str,
) -> Generator[None, None, None]:
    """Run the enclosed block inside a client-side span for one procedure call.

    The span is named
    ``river.client.<procedure_type>.<service_name>.<procedure_name>`` and is
    the current span for the duration of the ``with`` body.

    Args:
        procedure_type: Which kind of procedure is being invoked.
        service_name: Name of the target service; used in the span name.
        procedure_name: Name of the target procedure; used in the span name.

    Yields:
        None. The caller's ``with`` body executes while the span is active.

    Raises:
        RiverException: Re-raised unchanged after its ``code`` and ``message``
            have been copied onto the span as the ``river.error_code`` and
            ``river.error_message`` attributes. Any other exception type is
            not handled here and simply propagates.
    """
    span_name = f"river.client.{procedure_type}.{service_name}.{procedure_name}"
    with tracer.start_as_current_span(
        span_name,
        kind=trace.SpanKind.CLIENT,
    ) as span:
        try:
            yield
        except RiverException as e:
            span.set_attribute("river.error_code", e.code)
            span.set_attribute("river.error_message", e.message)
            # Bare ``raise`` re-raises the active exception as-is, without
            # adding a redundant traceback entry the way ``raise e`` does.
            raise

replit_river/rpc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import grpc
2323
from aiochannel import Channel, ChannelClosed
24+
from opentelemetry.propagators.textmap import Setter
2425
from pydantic import BaseModel, ConfigDict, Field
2526

2627
from replit_river.error_schema import (
@@ -85,6 +86,11 @@ class ControlMessageHandshakeResponse(BaseModel):
8586
status: HandShakeStatus
8687

8788

89+
class PropagationContext(BaseModel):
    # Trace-propagation fields carried on a TransportMessage via its optional
    # ``tracing`` attribute. Both fields are required strings.

    # Value stored under the "traceparent" propagation key.
    traceparent: str
    # Value stored under the "tracestate" propagation key.
    tracestate: str
92+
93+
8894
class TransportMessage(BaseModel):
8995
id: str
9096
# from_ is used instead of from because from is a reserved keyword in Python
@@ -96,12 +102,30 @@ class TransportMessage(BaseModel):
96102
procedureName: Optional[str] = None
97103
streamId: str
98104
controlFlags: int
105+
tracing: Optional[PropagationContext] = None
99106
payload: Any
100107
model_config = ConfigDict(populate_by_name=True)
101108
# need this because we create TransportMessage objects with destructuring
102109
# where the key is "from"
103110

104111

112+
class TransportMessageTracingSetter(Setter[TransportMessage]):
    """
    Propagator setter that copies tracing context onto an outgoing message so
    the recipient can pick it up.
    """

    def set(self, carrier: TransportMessage, key: str, value: str) -> None:
        """
        Store one propagation field on ``carrier.tracing``.

        ``carrier.tracing`` is created with empty strings on first use. Only
        the "traceparent" and "tracestate" keys are stored; any other key is
        logged as a warning and otherwise ignored.
        """
        if not carrier.tracing:
            carrier.tracing = PropagationContext(traceparent="", tracestate="")
        tracing = carrier.tracing
        if key == "traceparent":
            tracing.traceparent = value
        elif key == "tracestate":
            tracing.tracestate = value
        else:
            logger.warning("unknown trace propagation key", extra={"key": key})
127+
128+
105129
class GrpcContext(grpc.aio.ServicerContext):
106130
"""Represents a gRPC-compatible ServicerContext for River interop."""
107131

replit_river/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import nanoid # type: ignore
77
import websockets
88
from aiochannel import Channel, ChannelClosed
9+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
910
from websockets.exceptions import ConnectionClosed
1011

1112
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
@@ -31,6 +32,7 @@
3132
STREAM_OPEN_BIT,
3233
GenericRpcHandler,
3334
TransportMessage,
35+
TransportMessageTracingSetter,
3436
)
3537

3638
logger = logging.getLogger(__name__)
@@ -380,6 +382,9 @@ async def send_message(
380382
serviceName=service_name,
381383
procedureName=procedure_name,
382384
)
385+
TraceContextTextMapPropagator().inject(
386+
msg, None, TransportMessageTracingSetter()
387+
)
383388
try:
384389
# We need this lock to ensure the buffer order and message sending order
385390
# are the same.

scripts/lint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/bin/bash
1+
#!/usr/bin/env bash
22

33
set -ex
44

0 commit comments

Comments
 (0)