11import logging
22from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
3- from typing import Any , Generic , Optional , Union
3+ from contextlib import contextmanager
4+ from typing import Any , Generator , Generic , Literal , Optional , Union
5+
6+ from opentelemetry import trace
47
58from replit_river .client_transport import ClientTransport
9+ from replit_river .error_schema import RiverException
610from replit_river .transport_options import (
711 HandshakeMetadataType ,
812 TransportOptions ,
1721)
1822
1923logger = logging .getLogger (__name__ )
24+ tracer = trace .get_tracer (__name__ )
2025
2126
2227class Client (Generic [HandshakeMetadataType ]):
@@ -55,15 +60,16 @@ async def send_rpc(
5560 response_deserializer : Callable [[Any ], ResponseType ],
5661 error_deserializer : Callable [[Any ], ErrorType ],
5762 ) -> ResponseType :
58- session = await self ._transport .get_or_create_session ()
59- return await session .send_rpc (
60- service_name ,
61- procedure_name ,
62- request ,
63- request_serializer ,
64- response_deserializer ,
65- error_deserializer ,
66- )
63+ with _trace_procedure ("rpc" , service_name , procedure_name ):
64+ session = await self ._transport .get_or_create_session ()
65+ return await session .send_rpc (
66+ service_name ,
67+ procedure_name ,
68+ request ,
69+ request_serializer ,
70+ response_deserializer ,
71+ error_deserializer ,
72+ )
6773
6874 async def send_upload (
6975 self ,
@@ -76,17 +82,18 @@ async def send_upload(
7682 response_deserializer : Callable [[Any ], ResponseType ],
7783 error_deserializer : Callable [[Any ], ErrorType ],
7884 ) -> ResponseType :
79- session = await self ._transport .get_or_create_session ()
80- return await session .send_upload (
81- service_name ,
82- procedure_name ,
83- init ,
84- request ,
85- init_serializer ,
86- request_serializer ,
87- response_deserializer ,
88- error_deserializer ,
89- )
85+ with _trace_procedure ("upload" , service_name , procedure_name ):
86+ session = await self ._transport .get_or_create_session ()
87+ return await session .send_upload (
88+ service_name ,
89+ procedure_name ,
90+ init ,
91+ request ,
92+ init_serializer ,
93+ request_serializer ,
94+ response_deserializer ,
95+ error_deserializer ,
96+ )
9097
9198 async def send_subscription (
9299 self ,
@@ -97,15 +104,16 @@ async def send_subscription(
97104 response_deserializer : Callable [[Any ], ResponseType ],
98105 error_deserializer : Callable [[Any ], ErrorType ],
99106 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
100- session = await self ._transport .get_or_create_session ()
101- return session .send_subscription (
102- service_name ,
103- procedure_name ,
104- request ,
105- request_serializer ,
106- response_deserializer ,
107- error_deserializer ,
108- )
107+ with _trace_procedure ("subscription" , service_name , procedure_name ):
108+ session = await self ._transport .get_or_create_session ()
109+ return session .send_subscription (
110+ service_name ,
111+ procedure_name ,
112+ request ,
113+ request_serializer ,
114+ response_deserializer ,
115+ error_deserializer ,
116+ )
109117
110118 async def send_stream (
111119 self ,
@@ -118,14 +126,33 @@ async def send_stream(
118126 response_deserializer : Callable [[Any ], ResponseType ],
119127 error_deserializer : Callable [[Any ], ErrorType ],
120128 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
121- session = await self ._transport .get_or_create_session ()
122- return session .send_stream (
123- service_name ,
124- procedure_name ,
125- init ,
126- request ,
127- init_serializer ,
128- request_serializer ,
129- response_deserializer ,
130- error_deserializer ,
131- )
129+ with _trace_procedure ("stream" , service_name , procedure_name ):
130+ session = await self ._transport .get_or_create_session ()
131+ return session .send_stream (
132+ service_name ,
133+ procedure_name ,
134+ init ,
135+ request ,
136+ init_serializer ,
137+ request_serializer ,
138+ response_deserializer ,
139+ error_deserializer ,
140+ )
141+
142+
143+ @contextmanager
144+ def _trace_procedure (
145+ procedure_type : Literal ["rpc" , "upload" , "subscription" , "stream" ],
146+ service_name : str ,
147+ procedure_name : str ,
148+ ) -> Generator [None , None , None ]:
149+ with tracer .start_as_current_span (
150+ f"river.client.{ procedure_type } .{ service_name } .{ procedure_name } " ,
151+ kind = trace .SpanKind .CLIENT ,
152+ ) as span :
153+ try :
154+ yield
155+ except RiverException as e :
156+ span .set_attribute ("river.error_code" , e .code )
157+ span .set_attribute ("river.error_message" , e .message )
158+ raise e
0 commit comments