Skip to content

Commit edf77b9

Browse files
committed
Implement interceptors to support OTel tracing
1 parent 8b7034d commit edf77b9

File tree

6 files changed

+385
-43
lines changed

6 files changed

+385
-43
lines changed

pyproject.toml

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ classifiers = [
2121
]
2222
dependencies = [
2323
"pydantic==2.9.2",
24-
"aiochannel>=1.2.1",
25-
"black>=23.11,<25.0",
26-
"grpcio-tools>=1.59.3",
27-
"grpcio>=1.59.3",
28-
"msgpack-types>=0.3.0",
29-
"msgpack>=1.0.7",
30-
"nanoid>=2.0.0",
31-
"protobuf>=4.24.4",
32-
"pydantic-core>=2.20.1",
33-
"websockets>=12.0",
24+
"aiochannel>=1.2.1",
25+
"black>=23.11,<25.0",
26+
"grpcio-tools>=1.59.3",
27+
"grpcio>=1.59.3",
28+
"msgpack-types>=0.3.0",
29+
"msgpack>=1.0.7",
30+
"nanoid>=2.0.0",
31+
"protobuf>=4.24.4",
32+
"pydantic-core>=2.20.1",
33+
"websockets>=12.0",
34+
"opentelemetry-sdk>=1.28.1",
35+
"opentelemetry-api>=1.28.1",
3436
]
3537

3638
[tool.uv]

replit_river/client.py

Lines changed: 117 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
33
from typing import Any, Generic, Optional, Union
44

5+
from replit_river.client_interceptor import (
6+
ClientInterceptor,
7+
ClientRpcDetails,
8+
ClientStreamDetails,
9+
ClientSubscriptionDetails,
10+
ClientUploadDetails,
11+
)
512
from replit_river.client_transport import ClientTransport
613
from replit_river.transport_options import (
714
HandshakeMetadataType,
@@ -28,6 +35,7 @@ def __init__(
2835
client_id: str,
2936
server_id: str,
3037
transport_options: TransportOptions,
38+
interceptors: list[ClientInterceptor] = [],
3139
) -> None:
3240
self._client_id = client_id
3341
self._server_id = server_id
@@ -37,6 +45,7 @@ def __init__(
3745
server_id=server_id,
3846
transport_options=transport_options,
3947
)
48+
self._interceptors = interceptors
4049

4150
async def close(self) -> None:
4251
logger.info(f"river client {self._client_id} start closing")
@@ -56,13 +65,32 @@ async def send_rpc(
5665
error_deserializer: Callable[[Any], ErrorType],
5766
) -> ResponseType:
5867
session = await self._transport.get_or_create_session()
59-
return await session.send_rpc(
60-
service_name,
61-
procedure_name,
62-
request,
63-
request_serializer,
64-
response_deserializer,
65-
error_deserializer,
68+
69+
async def _run_interceptor(
70+
details: ClientRpcDetails,
71+
interceptors: list[ClientInterceptor],
72+
) -> ResponseType:
73+
if interceptors:
74+
return await interceptors[0].intercept_rpc(
75+
details, lambda details: _run_interceptor(details, interceptors[1:])
76+
)
77+
else:
78+
return await session.send_rpc(
79+
details.service_name,
80+
details.procedure_name,
81+
details.request,
82+
request_serializer,
83+
response_deserializer,
84+
error_deserializer,
85+
)
86+
87+
return await _run_interceptor(
88+
ClientRpcDetails(
89+
service_name=service_name,
90+
procedure_name=procedure_name,
91+
request=request,
92+
),
93+
self._interceptors,
6694
)
6795

6896
async def send_upload(
@@ -77,15 +105,34 @@ async def send_upload(
77105
error_deserializer: Callable[[Any], ErrorType],
78106
) -> ResponseType:
79107
session = await self._transport.get_or_create_session()
80-
return await session.send_upload(
81-
service_name,
82-
procedure_name,
83-
init,
84-
request,
85-
init_serializer,
86-
request_serializer,
87-
response_deserializer,
88-
error_deserializer,
108+
109+
async def _run_interceptor(
110+
details: ClientUploadDetails,
111+
interceptors: list[ClientInterceptor],
112+
) -> ResponseType:
113+
if interceptors:
114+
return await interceptors[0].intercept_upload(
115+
details, lambda details: _run_interceptor(details, interceptors[1:])
116+
)
117+
else:
118+
return await session.send_upload(
119+
service_name,
120+
procedure_name,
121+
init,
122+
request,
123+
init_serializer,
124+
request_serializer,
125+
response_deserializer,
126+
error_deserializer,
127+
)
128+
129+
return await _run_interceptor(
130+
ClientUploadDetails(
131+
service_name=service_name,
132+
procedure_name=procedure_name,
133+
init=init,
134+
),
135+
self._interceptors,
89136
)
90137

91138
async def send_subscription(
@@ -98,13 +145,32 @@ async def send_subscription(
98145
error_deserializer: Callable[[Any], ErrorType],
99146
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
100147
session = await self._transport.get_or_create_session()
101-
return session.send_subscription(
102-
service_name,
103-
procedure_name,
104-
request,
105-
request_serializer,
106-
response_deserializer,
107-
error_deserializer,
148+
149+
async def _run_interceptor(
150+
details: ClientSubscriptionDetails,
151+
interceptors: list[ClientInterceptor],
152+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
153+
if interceptors:
154+
return await interceptors[0].intercept_subscription(
155+
details, lambda details: _run_interceptor(details, interceptors[1:])
156+
)
157+
else:
158+
return session.send_subscription(
159+
service_name,
160+
procedure_name,
161+
request,
162+
request_serializer,
163+
response_deserializer,
164+
error_deserializer,
165+
)
166+
167+
return await _run_interceptor(
168+
ClientSubscriptionDetails(
169+
service_name=service_name,
170+
procedure_name=procedure_name,
171+
request=request,
172+
),
173+
self._interceptors,
108174
)
109175

110176
async def send_stream(
@@ -119,13 +185,32 @@ async def send_stream(
119185
error_deserializer: Callable[[Any], ErrorType],
120186
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
121187
session = await self._transport.get_or_create_session()
122-
return session.send_stream(
123-
service_name,
124-
procedure_name,
125-
init,
126-
request,
127-
init_serializer,
128-
request_serializer,
129-
response_deserializer,
130-
error_deserializer,
188+
189+
async def _run_interceptor(
190+
details: ClientStreamDetails,
191+
interceptors: list[ClientInterceptor],
192+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
193+
if interceptors:
194+
return await interceptors[0].intercept_stream(
195+
details, lambda details: _run_interceptor(details, interceptors[1:])
196+
)
197+
else:
198+
return session.send_stream(
199+
service_name,
200+
procedure_name,
201+
init,
202+
request,
203+
init_serializer,
204+
request_serializer,
205+
response_deserializer,
206+
error_deserializer,
207+
)
208+
209+
return await _run_interceptor(
210+
ClientStreamDetails(
211+
service_name=service_name,
212+
procedure_name=procedure_name,
213+
init=init,
214+
),
215+
self._interceptors,
131216
)

replit_river/client_interceptor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, AsyncIterator, Awaitable, Callable, NamedTuple, Optional
3+
4+
5+
class ClientRpcDetails(NamedTuple):
6+
service_name: str
7+
procedure_name: str
8+
request: Any
9+
10+
11+
class ClientUploadDetails(NamedTuple):
12+
service_name: str
13+
procedure_name: str
14+
init: Optional[Any]
15+
16+
17+
class ClientSubscriptionDetails(NamedTuple):
18+
service_name: str
19+
procedure_name: str
20+
request: Any
21+
22+
23+
class ClientStreamDetails(NamedTuple):
24+
service_name: str
25+
procedure_name: str
26+
init: Optional[Any]
27+
28+
29+
class ClientInterceptor(ABC):
30+
@abstractmethod
31+
async def intercept_rpc(
32+
self,
33+
details: ClientRpcDetails,
34+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
35+
) -> Any:
36+
"""
37+
TODO: docs
38+
"""
39+
40+
@abstractmethod
41+
async def intercept_upload(
42+
self,
43+
details: ClientUploadDetails,
44+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
45+
) -> Any:
46+
"""
47+
TODO: docs
48+
"""
49+
50+
@abstractmethod
51+
async def intercept_subscription(
52+
self,
53+
details: ClientSubscriptionDetails,
54+
continuation: Callable[
55+
[ClientSubscriptionDetails],
56+
Awaitable[AsyncIterator[Any]],
57+
],
58+
) -> AsyncIterator[Any]:
59+
"""
60+
TODO: docs
61+
"""
62+
63+
@abstractmethod
64+
async def intercept_stream(
65+
self,
66+
details: ClientStreamDetails,
67+
continuation: Callable[
68+
[ClientStreamDetails],
69+
Awaitable[AsyncIterator[Any]],
70+
],
71+
) -> AsyncIterator[Any]:
72+
"""
73+
TODO: docs
74+
"""
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Any, Awaitable, Callable
2+
3+
from opentelemetry import trace
4+
5+
from replit_river.client_interceptor import (
6+
ClientInterceptor,
7+
ClientRpcDetails,
8+
ClientStreamDetails,
9+
ClientSubscriptionDetails,
10+
ClientUploadDetails,
11+
)
12+
from replit_river.error_schema import RiverException
13+
14+
tracer = trace.get_tracer(__name__)
15+
16+
17+
class OpenTelemetryClientInterceptor(ClientInterceptor):
18+
async def intercept_rpc(
19+
self,
20+
details: ClientRpcDetails,
21+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
22+
) -> Any:
23+
with tracer.start_as_current_span(
24+
f"river.rpc.{details.service_name}.{details.procedure_name}"
25+
) as span:
26+
try:
27+
return await continuation(details)
28+
except RiverException as e:
29+
span.set_attribute("river.error_code", e.code)
30+
span.set_attribute("river.error_message", e.message)
31+
return e
32+
33+
async def intercept_upload(
34+
self,
35+
details: ClientUploadDetails,
36+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
37+
) -> Any:
38+
with tracer.start_as_current_span(
39+
f"river.upload.{details.service_name}.{details.procedure_name}"
40+
) as span:
41+
try:
42+
return await continuation(details)
43+
except RiverException as e:
44+
span.set_attribute("river.error_code", e.code)
45+
span.set_attribute("river.error_message", e.message)
46+
return e
47+
48+
async def intercept_subscription(
49+
self,
50+
details: ClientSubscriptionDetails,
51+
continuation: Callable[[ClientSubscriptionDetails], Awaitable[Any]],
52+
) -> Any:
53+
with tracer.start_as_current_span(
54+
f"river.subscription.{details.service_name}.{details.procedure_name}"
55+
) as span:
56+
try:
57+
return await continuation(details)
58+
except RiverException as e:
59+
span.set_attribute("river.error_code", e.code)
60+
span.set_attribute("river.error_message", e.message)
61+
return e
62+
63+
async def intercept_stream(
64+
self,
65+
details: ClientStreamDetails,
66+
continuation: Callable[[ClientStreamDetails], Awaitable[Any]],
67+
) -> Any:
68+
with tracer.start_as_current_span(
69+
f"river.stream.{details.service_name}.{details.procedure_name}"
70+
) as span:
71+
try:
72+
return await continuation(details)
73+
except RiverException as e:
74+
span.set_attribute("river.error_code", e.code)
75+
span.set_attribute("river.error_message", e.message)
76+
return e

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ async def client(
138138
transport_options: TransportOptions,
139139
no_logging_error: NoErrors,
140140
) -> AsyncGenerator[Client, None]:
141-
142141
async def websocket_uri_factory() -> UriAndMetadata[None]:
143142
return {
144143
"uri": "ws://localhost:8765",

0 commit comments

Comments
 (0)