1111 emit_response_events ,
1212)
1313from opentelemetry .instrumentation .cohere .span_utils import (
14- set_input_attributes ,
15- set_response_attributes ,
14+ set_input_content_attributes ,
15+ set_response_content_attributes ,
1616 set_span_request_attributes ,
17+ set_span_response_attributes ,
18+ )
19+ from opentelemetry .instrumentation .cohere .streaming import (
20+ process_chat_v1_streaming_response ,
21+ aprocess_chat_v1_streaming_response ,
22+ process_chat_v2_streaming_response ,
23+ aprocess_chat_v2_streaming_response ,
1724)
1825from opentelemetry .instrumentation .cohere .utils import dont_throw , should_emit_events
1926from opentelemetry .instrumentation .cohere .version import __version__
2734 LLMRequestTypeValues ,
2835 SpanAttributes ,
2936)
30- from opentelemetry .trace import SpanKind , Tracer , get_tracer
37+ from opentelemetry .trace import SpanKind , Status , StatusCode , Tracer , get_tracer , use_span
3138from wrapt import wrap_function_wrapper
3239
3340logger = logging .getLogger (__name__ )
3643
3744WRAPPED_METHODS = [
3845 {
46+ "module" : "cohere.client" ,
3947 "object" : "Client" ,
4048 "method" : "generate" ,
4149 "span_name" : "cohere.completion" ,
4250 },
4351 {
52+ "module" : "cohere.client" ,
4453 "object" : "Client" ,
4554 "method" : "chat" ,
4655 "span_name" : "cohere.chat" ,
4756 },
4857 {
58+ "module" : "cohere.client" ,
59+ "object" : "Client" ,
60+ "method" : "chat_stream" ,
61+ "span_name" : "cohere.chat" ,
62+ "stream_process_func" : process_chat_v1_streaming_response ,
63+ },
64+ {
65+ "module" : "cohere.client" ,
66+ "object" : "Client" ,
67+ "method" : "rerank" ,
68+ "span_name" : "cohere.rerank" ,
69+ },
70+ {
71+ "module" : "cohere.client" ,
4972 "object" : "Client" ,
73+ "method" : "embed" ,
74+ "span_name" : "cohere.embed" ,
75+ },
76+ {
77+ "module" : "cohere.client_v2" ,
78+ "object" : "ClientV2" ,
79+ "method" : "chat" ,
80+ "span_name" : "cohere.chat" ,
81+ },
82+ {
83+ "module" : "cohere.client_v2" ,
84+ "object" : "ClientV2" ,
85+ "method" : "chat_stream" ,
86+ "span_name" : "cohere.chat" ,
87+ "stream_process_func" : process_chat_v2_streaming_response ,
88+ },
89+ {
90+ "module" : "cohere.client_v2" ,
91+ "object" : "ClientV2" ,
5092 "method" : "rerank" ,
5193 "span_name" : "cohere.rerank" ,
5294 },
95+ {
96+ "module" : "cohere.client_v2" ,
97+ "object" : "ClientV2" ,
98+ "method" : "embed" ,
99+ "span_name" : "cohere.embed" ,
100+ },
101+ # Async methods that return AsyncIterator must be wrapped with sync wrapper
102+ {
103+ "module" : "cohere.client" ,
104+ "object" : "AsyncClient" ,
105+ "method" : "chat_stream" ,
106+ "span_name" : "cohere.chat" ,
107+ "stream_process_func" : aprocess_chat_v1_streaming_response ,
108+ },
109+ {
110+ "module" : "cohere.client_v2" ,
111+ "object" : "AsyncClientV2" ,
112+ "method" : "chat_stream" ,
113+ "span_name" : "cohere.chat" ,
114+ "stream_process_func" : aprocess_chat_v2_streaming_response ,
115+ },
116+ ]
117+
118+ WRAPPED_AMETHODS = [
119+ {
120+ "module" : "cohere.client" ,
121+ "object" : "AsyncClient" ,
122+ "method" : "generate" ,
123+ "span_name" : "cohere.completion" ,
124+ },
125+ {
126+ "module" : "cohere.client" ,
127+ "object" : "AsyncClient" ,
128+ "method" : "chat" ,
129+ "span_name" : "cohere.chat" ,
130+ },
131+ {
132+ "module" : "cohere.client" ,
133+ "object" : "AsyncClient" ,
134+ "method" : "rerank" ,
135+ "span_name" : "cohere.rerank" ,
136+ },
137+ {
138+ "module" : "cohere.client" ,
139+ "object" : "AsyncClient" ,
140+ "method" : "embed" ,
141+ "span_name" : "cohere.embed" ,
142+ },
143+ {
144+ "module" : "cohere.client_v2" ,
145+ "object" : "AsyncClientV2" ,
146+ "method" : "chat" ,
147+ "span_name" : "cohere.chat" ,
148+ },
149+ {
150+ "module" : "cohere.client_v2" ,
151+ "object" : "AsyncClientV2" ,
152+ "method" : "rerank" ,
153+ "span_name" : "cohere.rerank" ,
154+ },
155+ {
156+ "module" : "cohere.client_v2" ,
157+ "object" : "AsyncClientV2" ,
158+ "method" : "embed" ,
159+ "span_name" : "cohere.embed" ,
160+ },
53161]
54162
55163
@@ -66,30 +174,30 @@ def wrapper(wrapped, instance, args, kwargs):
66174
67175
68176def _llm_request_type_by_method (method_name ):
69- if method_name == "chat" :
177+ if method_name in [ "chat" , "chat_stream" ] :
70178 return LLMRequestTypeValues .CHAT
71- elif method_name == "generate" :
179+ elif method_name in [ "generate" , "generate_stream" ] :
72180 return LLMRequestTypeValues .COMPLETION
73181 elif method_name == "rerank" :
74182 return LLMRequestTypeValues .RERANK
183+ elif method_name == "embed" :
184+ return LLMRequestTypeValues .EMBEDDING
75185 else :
76186 return LLMRequestTypeValues .UNKNOWN
77187
78188
79189@dont_throw
80- def _handle_input (span , event_logger , llm_request_type , kwargs ):
190+ def _handle_input_content (span , event_logger , llm_request_type , kwargs ):
191+ set_input_content_attributes (span , llm_request_type , kwargs )
81192 if should_emit_events ():
82193 emit_input_event (event_logger , llm_request_type , kwargs )
83- else :
84- set_input_attributes (span , llm_request_type , kwargs )
85194
86195
87196@dont_throw
88- def _handle_response (span , event_logger , llm_request_type , response ):
197+ def _handle_response_content (span , event_logger , llm_request_type , response ):
198+ set_response_content_attributes (span , llm_request_type , response )
89199 if should_emit_events ():
90200 emit_response_events (event_logger , llm_request_type , response )
91- else :
92- set_response_attributes (span , llm_request_type , response )
93201
94202
95203@_with_tracer_wrapper
@@ -108,6 +216,55 @@ def _wrap(
108216 ):
109217 return wrapped (* args , ** kwargs )
110218
219+ name = to_wrap .get ("span_name" )
220+ llm_request_type = _llm_request_type_by_method (to_wrap .get ("method" ))
221+ span = tracer .start_span (
222+ name ,
223+ kind = SpanKind .CLIENT ,
224+ attributes = {
225+ SpanAttributes .LLM_SYSTEM : "Cohere" ,
226+ SpanAttributes .LLM_REQUEST_TYPE : llm_request_type .value ,
227+ },
228+ )
229+
230+ with use_span (span , end_on_exit = False ):
231+ set_span_request_attributes (span , kwargs )
232+ _handle_input_content (span , event_logger , llm_request_type , kwargs )
233+
234+ try :
235+ response = wrapped (* args , ** kwargs )
236+ except Exception as e :
237+ if span .is_recording ():
238+ span .set_status (Status (StatusCode .ERROR , str (e )))
239+ span .record_exception (e )
240+ span .end ()
241+ raise
242+
243+ if to_wrap .get ("stream_process_func" ):
244+ return to_wrap .get ("stream_process_func" )(span , event_logger , llm_request_type , response )
245+
246+ set_span_response_attributes (span , response )
247+ _handle_response_content (span , event_logger , llm_request_type , response )
248+ span .end ()
249+ return response
250+
251+
252+ @_with_tracer_wrapper
253+ async def _awrap (
254+ tracer : Tracer ,
255+ event_logger : Union [EventLogger , None ],
256+ to_wrap ,
257+ wrapped ,
258+ instance ,
259+ args ,
260+ kwargs ,
261+ ):
262+ """Instruments and calls every function defined in TO_WRAP."""
263+ if context_api .get_value (_SUPPRESS_INSTRUMENTATION_KEY ) or context_api .get_value (
264+ SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
265+ ):
266+ return await wrapped (* args , ** kwargs )
267+
111268 name = to_wrap .get ("span_name" )
112269 llm_request_type = _llm_request_type_by_method (to_wrap .get ("method" ))
113270 with tracer .start_as_current_span (
@@ -119,12 +276,19 @@ def _wrap(
119276 },
120277 ) as span :
121278 set_span_request_attributes (span , kwargs )
122- _handle_input (span , event_logger , llm_request_type , kwargs )
279+ _handle_input_content (span , event_logger , llm_request_type , kwargs )
123280
124- response = wrapped (* args , ** kwargs )
281+ try :
282+ response = await wrapped (* args , ** kwargs )
283+ except Exception as e :
284+ if span .is_recording ():
285+ span .set_status (Status (StatusCode .ERROR , str (e )))
286+ span .record_exception (e )
287+ span .end ()
288+ raise
125289
126- if response :
127- _handle_response (span , event_logger , llm_request_type , response )
290+ set_span_response_attributes ( span , response )
291+ _handle_response_content (span , event_logger , llm_request_type , response )
128292
129293 return response
130294
@@ -151,18 +315,51 @@ def _instrument(self, **kwargs):
151315 __name__ , __version__ , event_logger_provider = event_logger_provider
152316 )
153317 for wrapped_method in WRAPPED_METHODS :
318+ wrap_module = wrapped_method .get ("module" )
154319 wrap_object = wrapped_method .get ("object" )
155320 wrap_method = wrapped_method .get ("method" )
156- wrap_function_wrapper (
157- "cohere.client" ,
158- f"{ wrap_object } .{ wrap_method } " ,
159- _wrap (tracer , event_logger , wrapped_method ),
160- )
321+ try :
322+ wrap_function_wrapper (
323+ wrap_module ,
324+ f"{ wrap_object } .{ wrap_method } " ,
325+ _wrap (tracer , event_logger , wrapped_method ),
326+ )
327+ except (ImportError , ModuleNotFoundError , AttributeError ):
328+ logger .debug (f"Failed to instrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
329+
330+ for wrapped_method in WRAPPED_AMETHODS :
331+ wrap_module = wrapped_method .get ("module" )
332+ wrap_object = wrapped_method .get ("object" )
333+ wrap_method = wrapped_method .get ("method" )
334+ try :
335+ wrap_function_wrapper (
336+ wrap_module ,
337+ f"{ wrap_object } .{ wrap_method } " ,
338+ _awrap (tracer , event_logger , wrapped_method ),
339+ )
340+ except (ImportError , ModuleNotFoundError , AttributeError ):
341+ logger .debug (f"Failed to instrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
161342
162343 def _uninstrument (self , ** kwargs ):
163344 for wrapped_method in WRAPPED_METHODS :
345+ wrap_module = wrapped_method .get ("module" )
164346 wrap_object = wrapped_method .get ("object" )
165- unwrap (
166- f"cohere.client.{ wrap_object } " ,
167- wrapped_method .get ("method" ),
168- )
347+ wrap_method = wrapped_method .get ("method" )
348+ try :
349+ unwrap (
350+ f"{ wrap_module } .{ wrap_object } " ,
351+ wrap_method ,
352+ )
353+ except (ImportError , ModuleNotFoundError , AttributeError ):
354+ logger .debug (f"Failed to uninstrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
355+ for wrapped_method in WRAPPED_AMETHODS :
356+ wrap_module = wrapped_method .get ("module" )
357+ wrap_object = wrapped_method .get ("object" )
358+ wrap_method = wrapped_method .get ("method" )
359+ try :
360+ unwrap (
361+ f"{ wrap_module } .{ wrap_object } " ,
362+ wrap_method ,
363+ )
364+ except (ImportError , ModuleNotFoundError , AttributeError ):
365+ logger .debug (f"Failed to uninstrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
0 commit comments