2020 Any ,
2121 Awaitable ,
2222 Callable ,
23+ Literal ,
2324 MutableSequence ,
25+ Union ,
26+ cast ,
27+ overload ,
2428)
2529
2630from opentelemetry ._events import EventLogger
31+ from opentelemetry .instrumentation ._semconv import (
32+ _StabilityMode ,
33+ )
2734from opentelemetry .instrumentation .vertexai .utils import (
2835 GenerateContentParams ,
36+ create_operation_details_event ,
2937 get_genai_request_attributes ,
3038 get_genai_response_attributes ,
3139 get_server_attributes ,
3442 response_to_events ,
3543)
3644from opentelemetry .trace import SpanKind , Tracer
45+ from opentelemetry .util .genai .types import ContentCapturingMode
3746
3847if TYPE_CHECKING :
3948 from google .cloud .aiplatform_v1 .services .prediction_service import client
@@ -89,17 +98,96 @@ def _extract_params(
8998 )
9099
91100
101+ # For details about GEN_AI_LATEST_EXPERIMENTAL stability mode see
102+ # https://github.com/open-telemetry/semantic-conventions/blob/v1.37.0/docs/gen-ai/gen-ai-agent-spans.md?plain=1#L18-L37
92103class MethodWrappers :
104+ @overload
105+ def __init__ (
106+ self ,
107+ tracer : Tracer ,
108+ event_logger : EventLogger ,
109+ capture_content : ContentCapturingMode ,
110+ sem_conv_opt_in_mode : Literal [
111+ _StabilityMode .GEN_AI_LATEST_EXPERIMENTAL
112+ ],
113+ ) -> None : ...
114+
115+ @overload
116+ def __init__ (
117+ self ,
118+ tracer : Tracer ,
119+ event_logger : EventLogger ,
120+ capture_content : bool ,
121+ sem_conv_opt_in_mode : Literal [_StabilityMode .DEFAULT ],
122+ ) -> None : ...
123+
93124 def __init__ (
94- self , tracer : Tracer , event_logger : EventLogger , capture_content : bool
125+ self ,
126+ tracer : Tracer ,
127+ event_logger : EventLogger ,
128+ capture_content : Union [bool , ContentCapturingMode ],
129+ sem_conv_opt_in_mode : Union [
130+ Literal [_StabilityMode .DEFAULT ],
131+ Literal [_StabilityMode .GEN_AI_LATEST_EXPERIMENTAL ],
132+ ],
95133 ) -> None :
96134 self .tracer = tracer
97135 self .event_logger = event_logger
98136 self .capture_content = capture_content
137+ self .sem_conv_opt_in_mode = sem_conv_opt_in_mode
138+
139+ @contextmanager
140+ def _with_new_instrumentation (
141+ self ,
142+ capture_content : ContentCapturingMode ,
143+ instance : client .PredictionServiceClient
144+ | client_v1beta1 .PredictionServiceClient ,
145+ args : Any ,
146+ kwargs : Any ,
147+ ):
148+ params = _extract_params (* args , ** kwargs )
149+ api_endpoint : str = instance .api_endpoint # type: ignore[reportUnknownMemberType]
150+ span_attributes = {
151+ ** get_genai_request_attributes (False , params ),
152+ ** get_server_attributes (api_endpoint ),
153+ }
154+
155+ span_name = get_span_name (span_attributes )
156+
157+ with self .tracer .start_as_current_span (
158+ name = span_name ,
159+ kind = SpanKind .CLIENT ,
160+ attributes = span_attributes ,
161+ ) as span :
162+
163+ def handle_response (
164+ response : prediction_service .GenerateContentResponse
165+ | prediction_service_v1beta1 .GenerateContentResponse
166+ | None ,
167+ ) -> None :
168+ if span .is_recording () and response :
169+ # When streaming, this is called multiple times so attributes would be
170+ # overwritten. In practice, it looks the API only returns the interesting
171+ # attributes on the last streamed response. However, I couldn't find
172+ # documentation for this and setting attributes shouldn't be too expensive.
173+ span .set_attributes (
174+ get_genai_response_attributes (response )
175+ )
176+ self .event_logger .emit (
177+ create_operation_details_event (
178+ api_endpoint = api_endpoint ,
179+ params = params ,
180+ capture_content = capture_content ,
181+ response = response ,
182+ )
183+ )
184+
185+ yield handle_response
99186
100187 @contextmanager
101- def _with_instrumentation (
188+ def _with_default_instrumentation (
102189 self ,
190+ capture_content : bool ,
103191 instance : client .PredictionServiceClient
104192 | client_v1beta1 .PredictionServiceClient ,
105193 args : Any ,
@@ -108,7 +196,7 @@ def _with_instrumentation(
108196 params = _extract_params (* args , ** kwargs )
109197 api_endpoint : str = instance .api_endpoint # type: ignore[reportUnknownMemberType]
110198 span_attributes = {
111- ** get_genai_request_attributes (params ),
199+ ** get_genai_request_attributes (False , params ),
112200 ** get_server_attributes (api_endpoint ),
113201 }
114202
@@ -120,7 +208,7 @@ def _with_instrumentation(
120208 attributes = span_attributes ,
121209 ) as span :
122210 for event in request_to_events (
123- params = params , capture_content = self . capture_content
211+ params = params , capture_content = capture_content
124212 ):
125213 self .event_logger .emit (event )
126214
@@ -141,7 +229,7 @@ def handle_response(
141229 )
142230
143231 for event in response_to_events (
144- response = response , capture_content = self . capture_content
232+ response = response , capture_content = capture_content
145233 ):
146234 self .event_logger .emit (event )
147235
@@ -162,12 +250,25 @@ def generate_content(
162250 prediction_service .GenerateContentResponse
163251 | prediction_service_v1beta1 .GenerateContentResponse
164252 ):
165- with self ._with_instrumentation (
166- instance , args , kwargs
167- ) as handle_response :
168- response = wrapped (* args , ** kwargs )
169- handle_response (response )
170- return response
253+ if self .sem_conv_opt_in_mode == _StabilityMode .DEFAULT :
254+ capture_content_bool = cast (bool , self .capture_content )
255+ with self ._with_default_instrumentation (
256+ capture_content_bool , instance , args , kwargs
257+ ) as handle_response :
258+ response = wrapped (* args , ** kwargs )
259+ handle_response (response )
260+ return response
261+ else :
262+ capture_content = cast (ContentCapturingMode , self .capture_content )
263+ with self ._with_new_instrumentation (
264+ capture_content , instance , args , kwargs
265+ ) as handle_response :
266+ response = None
267+ try :
268+ response = wrapped (* args , ** kwargs )
269+ return response
270+ finally :
271+ handle_response (response )
171272
172273 async def agenerate_content (
173274 self ,
@@ -186,9 +287,22 @@ async def agenerate_content(
186287 prediction_service .GenerateContentResponse
187288 | prediction_service_v1beta1 .GenerateContentResponse
188289 ):
189- with self ._with_instrumentation (
190- instance , args , kwargs
191- ) as handle_response :
192- response = await wrapped (* args , ** kwargs )
193- handle_response (response )
194- return response
290+ if self .sem_conv_opt_in_mode == _StabilityMode .DEFAULT :
291+ capture_content_bool = cast (bool , self .capture_content )
292+ with self ._with_default_instrumentation (
293+ capture_content_bool , instance , args , kwargs
294+ ) as handle_response :
295+ response = await wrapped (* args , ** kwargs )
296+ handle_response (response )
297+ return response
298+ else :
299+ capture_content = cast (ContentCapturingMode , self .capture_content )
300+ with self ._with_new_instrumentation (
301+ capture_content , instance , args , kwargs
302+ ) as handle_response :
303+ response = None
304+ try :
305+ response = await wrapped (* args , ** kwargs )
306+ return response
307+ finally :
308+ handle_response (response )
0 commit comments