Skip to content

Commit 87c1210

Browse files
committed
Implement interceptors to support OTel tracing
1 parent 8b7034d commit 87c1210

File tree

6 files changed

+386
-43
lines changed

6 files changed

+386
-43
lines changed

pyproject.toml

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ classifiers = [
2121
]
2222
dependencies = [
2323
"pydantic==2.9.2",
24-
"aiochannel>=1.2.1",
25-
"black>=23.11,<25.0",
26-
"grpcio-tools>=1.59.3",
27-
"grpcio>=1.59.3",
28-
"msgpack-types>=0.3.0",
29-
"msgpack>=1.0.7",
30-
"nanoid>=2.0.0",
31-
"protobuf>=4.24.4",
32-
"pydantic-core>=2.20.1",
33-
"websockets>=12.0",
24+
"aiochannel>=1.2.1",
25+
"black>=23.11,<25.0",
26+
"grpcio-tools>=1.59.3",
27+
"grpcio>=1.59.3",
28+
"msgpack-types>=0.3.0",
29+
"msgpack>=1.0.7",
30+
"nanoid>=2.0.0",
31+
"protobuf>=4.24.4",
32+
"pydantic-core>=2.20.1",
33+
"websockets>=12.0",
34+
"opentelemetry-sdk>=1.28.1",
35+
"opentelemetry-api>=1.28.1",
3436
]
3537

3638
[tool.uv]

replit_river/client.py

Lines changed: 118 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import functools
12
import logging
23
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
34
from typing import Any, Generic, Optional, Union
45

6+
from replit_river.client_interceptor import (
7+
ClientRpcDetails,
8+
ClientInterceptor,
9+
ClientStreamDetails,
10+
ClientSubscriptionDetails,
11+
ClientUploadDetails,
12+
)
513
from replit_river.client_transport import ClientTransport
614
from replit_river.transport_options import (
715
HandshakeMetadataType,
@@ -28,6 +36,7 @@ def __init__(
2836
client_id: str,
2937
server_id: str,
3038
transport_options: TransportOptions,
39+
interceptors: list[ClientInterceptor] = [],
3140
) -> None:
3241
self._client_id = client_id
3342
self._server_id = server_id
@@ -37,6 +46,7 @@ def __init__(
3746
server_id=server_id,
3847
transport_options=transport_options,
3948
)
49+
self._interceptors = interceptors
4050

4151
async def close(self) -> None:
4252
logger.info(f"river client {self._client_id} start closing")
@@ -56,13 +66,32 @@ async def send_rpc(
5666
error_deserializer: Callable[[Any], ErrorType],
5767
) -> ResponseType:
5868
session = await self._transport.get_or_create_session()
59-
return await session.send_rpc(
60-
service_name,
61-
procedure_name,
62-
request,
63-
request_serializer,
64-
response_deserializer,
65-
error_deserializer,
69+
70+
async def _run_interceptor(
71+
details: ClientRpcDetails,
72+
interceptors: list[ClientInterceptor],
73+
) -> ResponseType:
74+
if interceptors:
75+
return await interceptors[0].intercept_rpc(
76+
details, lambda details: _run_interceptor(details, interceptors[1:])
77+
)
78+
else:
79+
return await session.send_rpc(
80+
details.service_name,
81+
details.procedure_name,
82+
details.request,
83+
request_serializer,
84+
response_deserializer,
85+
error_deserializer,
86+
)
87+
88+
return await _run_interceptor(
89+
ClientRpcDetails(
90+
service_name=service_name,
91+
procedure_name=procedure_name,
92+
request=request,
93+
),
94+
self._interceptors,
6695
)
6796

6897
async def send_upload(
@@ -77,15 +106,34 @@ async def send_upload(
77106
error_deserializer: Callable[[Any], ErrorType],
78107
) -> ResponseType:
79108
session = await self._transport.get_or_create_session()
80-
return await session.send_upload(
81-
service_name,
82-
procedure_name,
83-
init,
84-
request,
85-
init_serializer,
86-
request_serializer,
87-
response_deserializer,
88-
error_deserializer,
109+
110+
async def _run_interceptor(
111+
details: ClientUploadDetails,
112+
interceptors: list[ClientInterceptor],
113+
) -> ResponseType:
114+
if interceptors:
115+
return await interceptors[0].intercept_upload(
116+
details, lambda details: _run_interceptor(details, interceptors[1:])
117+
)
118+
else:
119+
return await session.send_upload(
120+
service_name,
121+
procedure_name,
122+
init,
123+
request,
124+
init_serializer,
125+
request_serializer,
126+
response_deserializer,
127+
error_deserializer,
128+
)
129+
130+
return await _run_interceptor(
131+
ClientUploadDetails(
132+
service_name=service_name,
133+
procedure_name=procedure_name,
134+
init=init,
135+
),
136+
self._interceptors,
89137
)
90138

91139
async def send_subscription(
@@ -98,13 +146,32 @@ async def send_subscription(
98146
error_deserializer: Callable[[Any], ErrorType],
99147
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
100148
session = await self._transport.get_or_create_session()
101-
return session.send_subscription(
102-
service_name,
103-
procedure_name,
104-
request,
105-
request_serializer,
106-
response_deserializer,
107-
error_deserializer,
149+
150+
async def _run_interceptor(
151+
details: ClientSubscriptionDetails,
152+
interceptors: list[ClientInterceptor],
153+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
154+
if interceptors:
155+
return await interceptors[0].intercept_subscription(
156+
details, lambda details: _run_interceptor(details, interceptors[1:])
157+
)
158+
else:
159+
return session.send_subscription(
160+
service_name,
161+
procedure_name,
162+
request,
163+
request_serializer,
164+
response_deserializer,
165+
error_deserializer,
166+
)
167+
168+
return await _run_interceptor(
169+
ClientSubscriptionDetails(
170+
service_name=service_name,
171+
procedure_name=procedure_name,
172+
request=request,
173+
),
174+
self._interceptors,
108175
)
109176

110177
async def send_stream(
@@ -119,13 +186,32 @@ async def send_stream(
119186
error_deserializer: Callable[[Any], ErrorType],
120187
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
121188
session = await self._transport.get_or_create_session()
122-
return session.send_stream(
123-
service_name,
124-
procedure_name,
125-
init,
126-
request,
127-
init_serializer,
128-
request_serializer,
129-
response_deserializer,
130-
error_deserializer,
189+
190+
async def _run_interceptor(
191+
details: ClientStreamDetails,
192+
interceptors: list[ClientInterceptor],
193+
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
194+
if interceptors:
195+
return await interceptors[0].intercept_stream(
196+
details, lambda details: _run_interceptor(details, interceptors[1:])
197+
)
198+
else:
199+
return session.send_stream(
200+
service_name,
201+
procedure_name,
202+
init,
203+
request,
204+
init_serializer,
205+
request_serializer,
206+
response_deserializer,
207+
error_deserializer,
208+
)
209+
210+
return await _run_interceptor(
211+
ClientStreamDetails(
212+
service_name=service_name,
213+
procedure_name=procedure_name,
214+
init=init,
215+
),
216+
self._interceptors,
131217
)

replit_river/client_interceptor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, AsyncIterator, Awaitable, Callable, NamedTuple, Optional
3+
4+
5+
class ClientRpcDetails(NamedTuple):
6+
service_name: str
7+
procedure_name: str
8+
request: Any
9+
10+
11+
class ClientUploadDetails(NamedTuple):
12+
service_name: str
13+
procedure_name: str
14+
init: Optional[Any]
15+
16+
17+
class ClientSubscriptionDetails(NamedTuple):
18+
service_name: str
19+
procedure_name: str
20+
request: Any
21+
22+
23+
class ClientStreamDetails(NamedTuple):
24+
service_name: str
25+
procedure_name: str
26+
init: Optional[Any]
27+
28+
29+
class ClientInterceptor(ABC):
30+
@abstractmethod
31+
async def intercept_rpc(
32+
self,
33+
details: ClientRpcDetails,
34+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
35+
) -> Any:
36+
"""
37+
TODO: docs
38+
"""
39+
40+
@abstractmethod
41+
async def intercept_upload(
42+
self,
43+
details: ClientUploadDetails,
44+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
45+
) -> Any:
46+
"""
47+
TODO: docs
48+
"""
49+
50+
@abstractmethod
51+
async def intercept_subscription(
52+
self,
53+
details: ClientSubscriptionDetails,
54+
continuation: Callable[
55+
[ClientSubscriptionDetails],
56+
Awaitable[AsyncIterator[Any]],
57+
],
58+
) -> AsyncIterator[Any]:
59+
"""
60+
TODO: docs
61+
"""
62+
63+
@abstractmethod
64+
async def intercept_stream(
65+
self,
66+
details: ClientStreamDetails,
67+
continuation: Callable[
68+
[ClientStreamDetails],
69+
Awaitable[AsyncIterator[Any]],
70+
],
71+
) -> AsyncIterator[Any]:
72+
"""
73+
TODO: docs
74+
"""
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Any, Awaitable, Callable
2+
3+
from opentelemetry import trace
4+
5+
from replit_river.client_interceptor import (
6+
ClientInterceptor,
7+
ClientRpcDetails,
8+
ClientStreamDetails,
9+
ClientSubscriptionDetails,
10+
ClientUploadDetails,
11+
)
12+
from replit_river.error_schema import RiverException
13+
14+
tracer = trace.get_tracer(__name__)
15+
16+
17+
class OpenTelemetryClientInterceptor(ClientInterceptor):
18+
async def intercept_rpc(
19+
self,
20+
details: ClientRpcDetails,
21+
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
22+
) -> Any:
23+
with tracer.start_as_current_span(
24+
f"river.rpc.{details.service_name}.{details.procedure_name}"
25+
) as span:
26+
try:
27+
return await continuation(details)
28+
except RiverException as e:
29+
span.set_attribute("river.error_code", e.code)
30+
span.set_attribute("river.error_message", e.message)
31+
return e
32+
33+
async def intercept_upload(
34+
self,
35+
details: ClientUploadDetails,
36+
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
37+
) -> Any:
38+
with tracer.start_as_current_span(
39+
f"river.upload.{details.service_name}.{details.procedure_name}"
40+
) as span:
41+
try:
42+
return await continuation(details)
43+
except RiverException as e:
44+
span.set_attribute("river.error_code", e.code)
45+
span.set_attribute("river.error_message", e.message)
46+
return e
47+
48+
async def intercept_subscription(
49+
self,
50+
details: ClientSubscriptionDetails,
51+
continuation: Callable[[ClientSubscriptionDetails], Awaitable[Any]],
52+
) -> Any:
53+
with tracer.start_as_current_span(
54+
f"river.subscription.{details.service_name}.{details.procedure_name}"
55+
) as span:
56+
try:
57+
return await continuation(details)
58+
except RiverException as e:
59+
span.set_attribute("river.error_code", e.code)
60+
span.set_attribute("river.error_message", e.message)
61+
return e
62+
63+
async def intercept_stream(
64+
self,
65+
details: ClientStreamDetails,
66+
continuation: Callable[[ClientStreamDetails], Awaitable[Any]],
67+
) -> Any:
68+
with tracer.start_as_current_span(
69+
f"river.stream.{details.service_name}.{details.procedure_name}"
70+
) as span:
71+
try:
72+
return await continuation(details)
73+
except RiverException as e:
74+
span.set_attribute("river.error_code", e.code)
75+
span.set_attribute("river.error_message", e.message)
76+
return e

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ async def client(
138138
transport_options: TransportOptions,
139139
no_logging_error: NoErrors,
140140
) -> AsyncGenerator[Client, None]:
141-
142141
async def websocket_uri_factory() -> UriAndMetadata[None]:
143142
return {
144143
"uri": "ws://localhost:8765",

0 commit comments

Comments
 (0)