Skip to content

Commit eabb561

Browse files
committed
Implement interceptors to support OTel tracing
1 parent 8b7034d commit eabb561

File tree

6 files changed

+396
-43
lines changed

6 files changed

+396
-43
lines changed

pyproject.toml

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ classifiers = [
2121
]
2222
dependencies = [
2323
"pydantic==2.9.2",
24-
"aiochannel>=1.2.1",
25-
"black>=23.11,<25.0",
26-
"grpcio-tools>=1.59.3",
27-
"grpcio>=1.59.3",
28-
"msgpack-types>=0.3.0",
29-
"msgpack>=1.0.7",
30-
"nanoid>=2.0.0",
31-
"protobuf>=4.24.4",
32-
"pydantic-core>=2.20.1",
33-
"websockets>=12.0",
24+
"aiochannel>=1.2.1",
25+
"black>=23.11,<25.0",
26+
"grpcio-tools>=1.59.3",
27+
"grpcio>=1.59.3",
28+
"msgpack-types>=0.3.0",
29+
"msgpack>=1.0.7",
30+
"nanoid>=2.0.0",
31+
"protobuf>=4.24.4",
32+
"pydantic-core>=2.20.1",
33+
"websockets>=12.0",
34+
"opentelemetry-sdk>=1.27.0,<1.28.0",
35+
"opentelemetry-api>=1.27.0,<1.28.0",
3436
]
3537

3638
[tool.uv]

replit_river/client.py

Lines changed: 122 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
33
from typing import Any, Generic, Optional, Union
44

5+
from replit_river.client_interceptor import (
6+
ClientInterceptor,
7+
ClientRpcDetails,
8+
ClientStreamDetails,
9+
ClientSubscriptionDetails,
10+
ClientUploadDetails,
11+
)
512
from replit_river.client_transport import ClientTransport
613
from replit_river.transport_options import (
714
HandshakeMetadataType,
@@ -28,6 +35,7 @@ def __init__(
2835
client_id: str,
2936
server_id: str,
3037
transport_options: TransportOptions,
38+
interceptors: list[ClientInterceptor] = [],
3139
) -> None:
3240
self._client_id = client_id
3341
self._server_id = server_id
@@ -37,6 +45,7 @@ def __init__(
3745
server_id=server_id,
3846
transport_options=transport_options,
3947
)
48+
self._interceptors = interceptors
4049

4150
async def close(self) -> None:
4251
logger.info(f"river client {self._client_id} start closing")
@@ -56,13 +65,34 @@ async def send_rpc(
5665
error_deserializer: Callable[[Any], ErrorType],
5766
) -> ResponseType:
5867
session = await self._transport.get_or_create_session()
59-
return await session.send_rpc(
60-
service_name,
61-
procedure_name,
62-
request,
63-
request_serializer,
64-
response_deserializer,
65-
error_deserializer,
68+
69+
async def _run_interceptor(
70+
details: ClientRpcDetails,
71+
interceptors: list[ClientInterceptor],
72+
) -> ResponseType:
73+
if interceptors:
74+
head, tail = interceptors[0], interceptors[1:]
75+
return await head.intercept_rpc( # type: ignore
76+
details,
77+
lambda details: _run_interceptor(details, tail),
78+
)
79+
else:
80+
return await session.send_rpc(
81+
details.service_name,
82+
details.procedure_name,
83+
details.request,
84+
request_serializer,
85+
response_deserializer,
86+
error_deserializer,
87+
)
88+
89+
return await _run_interceptor(
90+
ClientRpcDetails(
91+
service_name=service_name,
92+
procedure_name=procedure_name,
93+
request=request,
94+
),
95+
self._interceptors,
6696
)
6797

6898
async def send_upload(
@@ -77,15 +107,35 @@ async def send_upload(
77107
error_deserializer: Callable[[Any], ErrorType],
78108
) -> ResponseType:
79109
session = await self._transport.get_or_create_session()
80-
return await session.send_upload(
81-
service_name,
82-
procedure_name,
83-
init,
84-
request,
85-
init_serializer,
86-
request_serializer,
87-
response_deserializer,
88-
error_deserializer,
110+
111+
async def _run_interceptor(
112+
details: ClientUploadDetails,
113+
interceptors: list[ClientInterceptor],
114+
) -> ResponseType:
115+
if interceptors:
116+
head, tail = interceptors[0], interceptors[1:]
117+
return await head.intercept_upload( # type: ignore
118+
details, lambda details: _run_interceptor(details, tail)
119+
)
120+
else:
121+
return await session.send_upload(
122+
service_name,
123+
procedure_name,
124+
init,
125+
request,
126+
init_serializer,
127+
request_serializer,
128+
response_deserializer,
129+
error_deserializer,
130+
)
131+
132+
return await _run_interceptor(
133+
ClientUploadDetails(
134+
service_name=service_name,
135+
procedure_name=procedure_name,
136+
init=init,
137+
),
138+
self._interceptors,
89139
)
90140

91141
async def send_subscription(
@@ -98,13 +148,33 @@ async def send_subscription(
98148
error_deserializer: Callable[[Any], ErrorType],
99149
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
100150
session = await self._transport.get_or_create_session()
101-
return session.send_subscription(
102-
service_name,
103-
procedure_name,
104-
request,
105-
request_serializer,
106-
response_deserializer,
107-
error_deserializer,
151+
152+
async def _run_interceptor(
153+
details: ClientSubscriptionDetails,
154+
interceptors: list[ClientInterceptor],
155+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
156+
if interceptors:
157+
head, tail = interceptors[0], interceptors[1:]
158+
return await head.intercept_subscription( # type: ignore
159+
details, lambda details: _run_interceptor(details, tail)
160+
)
161+
else:
162+
return session.send_subscription(
163+
service_name,
164+
procedure_name,
165+
request,
166+
request_serializer,
167+
response_deserializer,
168+
error_deserializer,
169+
)
170+
171+
return await _run_interceptor(
172+
ClientSubscriptionDetails(
173+
service_name=service_name,
174+
procedure_name=procedure_name,
175+
request=request,
176+
),
177+
self._interceptors,
108178
)
109179

110180
async def send_stream(
@@ -119,13 +189,33 @@ async def send_stream(
119189
error_deserializer: Callable[[Any], ErrorType],
120190
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
121191
session = await self._transport.get_or_create_session()
122-
return session.send_stream(
123-
service_name,
124-
procedure_name,
125-
init,
126-
request,
127-
init_serializer,
128-
request_serializer,
129-
response_deserializer,
130-
error_deserializer,
192+
193+
async def _run_interceptor(
194+
details: ClientStreamDetails,
195+
interceptors: list[ClientInterceptor],
196+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
197+
if interceptors:
198+
head, tail = interceptors[0], interceptors[1:]
199+
return await head.intercept_stream( # type: ignore
200+
details, lambda details: _run_interceptor(details, tail)
201+
)
202+
else:
203+
return session.send_stream(
204+
service_name,
205+
procedure_name,
206+
init,
207+
request,
208+
init_serializer,
209+
request_serializer,
210+
response_deserializer,
211+
error_deserializer,
212+
)
213+
214+
return await _run_interceptor(
215+
ClientStreamDetails(
216+
service_name=service_name,
217+
procedure_name=procedure_name,
218+
init=init,
219+
),
220+
self._interceptors,
131221
)

replit_river/client_interceptor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, AsyncIterator, Awaitable, Callable, NamedTuple, Optional
3+
4+
5+
class ClientRpcDetails(NamedTuple):
6+
service_name: str
7+
procedure_name: str
8+
request: Any
9+
10+
11+
class ClientUploadDetails(NamedTuple):
12+
service_name: str
13+
procedure_name: str
14+
init: Optional[Any]
15+
16+
17+
class ClientSubscriptionDetails(NamedTuple):
18+
service_name: str
19+
procedure_name: str
20+
request: Any
21+
22+
23+
class ClientStreamDetails(NamedTuple):
24+
service_name: str
25+
procedure_name: str
26+
init: Optional[Any]
27+
28+
29+
class ClientInterceptor(ABC):
30+
@abstractmethod
31+
async def intercept_rpc(
32+
self,
33+
details: ClientRpcDetails,
34+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
35+
) -> Any:
36+
"""
37+
TODO: docs
38+
"""
39+
40+
@abstractmethod
41+
async def intercept_upload(
42+
self,
43+
details: ClientUploadDetails,
44+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
45+
) -> Any:
46+
"""
47+
TODO: docs
48+
"""
49+
50+
@abstractmethod
51+
async def intercept_subscription(
52+
self,
53+
details: ClientSubscriptionDetails,
54+
continuation: Callable[
55+
[ClientSubscriptionDetails],
56+
Awaitable[AsyncIterator[Any]],
57+
],
58+
) -> AsyncIterator[Any]:
59+
"""
60+
TODO: docs
61+
"""
62+
63+
@abstractmethod
64+
async def intercept_stream(
65+
self,
66+
details: ClientStreamDetails,
67+
continuation: Callable[
68+
[ClientStreamDetails],
69+
Awaitable[AsyncIterator[Any]],
70+
],
71+
) -> AsyncIterator[Any]:
72+
"""
73+
TODO: docs
74+
"""
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Any, AsyncIterator, Awaitable, Callable
2+
3+
from opentelemetry import trace
4+
5+
from replit_river.client_interceptor import (
6+
ClientInterceptor,
7+
ClientRpcDetails,
8+
ClientStreamDetails,
9+
ClientSubscriptionDetails,
10+
ClientUploadDetails,
11+
)
12+
from replit_river.error_schema import RiverException
13+
14+
tracer = trace.get_tracer(__name__)
15+
16+
17+
class OpenTelemetryClientInterceptor(ClientInterceptor):
18+
async def intercept_rpc(
19+
self,
20+
details: ClientRpcDetails,
21+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
22+
) -> Any:
23+
with tracer.start_as_current_span(
24+
f"river.rpc.{details.service_name}.{details.procedure_name}"
25+
) as span:
26+
try:
27+
return await continuation(details)
28+
except RiverException as e:
29+
span.set_attribute("river.error_code", e.code)
30+
span.set_attribute("river.error_message", e.message)
31+
return e
32+
33+
async def intercept_upload(
34+
self,
35+
details: ClientUploadDetails,
36+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
37+
) -> Any:
38+
with tracer.start_as_current_span(
39+
f"river.upload.{details.service_name}.{details.procedure_name}"
40+
) as span:
41+
try:
42+
return await continuation(details)
43+
except RiverException as e:
44+
span.set_attribute("river.error_code", e.code)
45+
span.set_attribute("river.error_message", e.message)
46+
return e
47+
48+
async def intercept_subscription(
49+
self,
50+
details: ClientSubscriptionDetails,
51+
continuation: Callable[
52+
[ClientSubscriptionDetails],
53+
Awaitable[AsyncIterator[Any]],
54+
],
55+
) -> Any:
56+
with tracer.start_as_current_span(
57+
f"river.subscription.{details.service_name}.{details.procedure_name}"
58+
) as span:
59+
try:
60+
return await continuation(details)
61+
except RiverException as e:
62+
span.set_attribute("river.error_code", e.code)
63+
span.set_attribute("river.error_message", e.message)
64+
return e
65+
66+
async def intercept_stream(
67+
self,
68+
details: ClientStreamDetails,
69+
continuation: Callable[
70+
[ClientStreamDetails],
71+
Awaitable[AsyncIterator[Any]],
72+
],
73+
) -> Any:
74+
with tracer.start_as_current_span(
75+
f"river.stream.{details.service_name}.{details.procedure_name}"
76+
) as span:
77+
try:
78+
return await continuation(details)
79+
except RiverException as e:
80+
span.set_attribute("river.error_code", e.code)
81+
span.set_attribute("river.error_message", e.message)
82+
return e

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ async def client(
138138
transport_options: TransportOptions,
139139
no_logging_error: NoErrors,
140140
) -> AsyncGenerator[Client, None]:
141-
142141
async def websocket_uri_factory() -> UriAndMetadata[None]:
143142
return {
144143
"uri": "ws://localhost:8765",

0 commit comments

Comments
 (0)