3535from opentelemetry .trace .status import Status , StatusCode
3636from wrapt import wrap_function_wrapper
3737
38- from mistralai .models . chat_completion import (
38+ from mistralai .models import (
3939 ChatCompletionResponse ,
40- ChatCompletionResponseChoice ,
41- ChatMessage ,
40+ ChatCompletionChoice ,
41+ AssistantMessage ,
42+ UserMessage ,
43+ SystemMessage ,
44+ UsageInfo ,
45+ EmbeddingResponse ,
4246)
43- from mistralai .models .common import UsageInfo
44- from mistralai .models .embeddings import EmbeddingResponse
4547
4648logger = logging .getLogger (__name__ )
4749
48- _instruments = ("mistralai >= 0.2.0, < 1 " ,)
50+ _instruments = ("mistralai >= 1.0.0 " ,)
4951
5052WRAPPED_METHODS = [
5153 {
52- "method" : "chat" ,
54+ "method" : "complete" ,
55+ "module" : "chat" ,
5356 "span_name" : "mistralai.chat" ,
5457 "streaming" : False ,
5558 },
5659 {
57- "method" : "chat_stream" ,
60+ "method" : "stream" ,
61+ "module" : "chat" ,
5862 "span_name" : "mistralai.chat" ,
5963 "streaming" : True ,
6064 },
6165 {
62- "method" : "embeddings" ,
66+ "method" : "create" ,
67+ "module" : "embeddings" ,
6368 "span_name" : "mistralai.embeddings" ,
6469 "streaming" : False ,
6570 },
@@ -92,7 +97,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
9297 message .role ,
9398 )
9499 else :
95- input = kwargs .get ("input" )
100+ input = kwargs .get ("input" ) or kwargs . get ( "inputs" )
96101
97102 if isinstance (input , str ):
98103 _set_span_attribute (
@@ -101,7 +106,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
101106 _set_span_attribute (
102107 span , f"{ SpanAttributes .LLM_PROMPTS } .0.content" , input
103108 )
104- else :
109+ elif input :
105110 for index , prompt in enumerate (input ):
106111 _set_span_attribute (
107112 span ,
@@ -205,20 +210,22 @@ def _accumulate_streaming_response(span, event_logger, llm_request_type, respons
205210 for res in response :
206211 yield res
207212
208- if res .model :
209- accumulated_response .model = res .model
210- if res .usage :
211- accumulated_response .usage = res .usage
213+ # Handle new CompletionEvent structure with .data attribute
214+ chunk_data = res .data if hasattr (res , 'data' ) else res
215+ if chunk_data .model :
216+ accumulated_response .model = chunk_data .model
217+ if chunk_data .usage :
218+ accumulated_response .usage = chunk_data .usage
212219 # Id is the same for all chunks, so it's safe to overwrite it every time
213- if res .id :
214- accumulated_response .id = res .id
220+ if chunk_data .id :
221+ accumulated_response .id = chunk_data .id
215222
216- for idx , choice in enumerate (res .choices ):
223+ for idx , choice in enumerate (chunk_data .choices ):
217224 if len (accumulated_response .choices ) <= idx :
218225 accumulated_response .choices .append (
219- ChatCompletionResponseChoice (
226+ ChatCompletionChoice (
220227 index = idx ,
221- message = ChatMessage (role = "assistant" , content = "" ),
228+ message = AssistantMessage (role = "assistant" , content = "" ),
222229 finish_reason = None ,
223230 )
224231 )
@@ -247,20 +254,22 @@ async def _aaccumulate_streaming_response(
247254 async for res in response :
248255 yield res
249256
250- if res .model :
251- accumulated_response .model = res .model
252- if res .usage :
253- accumulated_response .usage = res .usage
257+ # Handle new CompletionEvent structure with .data attribute
258+ chunk_data = res .data if hasattr (res , 'data' ) else res
259+ if chunk_data .model :
260+ accumulated_response .model = chunk_data .model
261+ if chunk_data .usage :
262+ accumulated_response .usage = chunk_data .usage
254263 # Id is the same for all chunks, so it's safe to overwrite it every time
255- if res .id :
256- accumulated_response .id = res .id
264+ if chunk_data .id :
265+ accumulated_response .id = chunk_data .id
257266
258- for idx , choice in enumerate (res .choices ):
267+ for idx , choice in enumerate (chunk_data .choices ):
259268 if len (accumulated_response .choices ) <= idx :
260269 accumulated_response .choices .append (
261- ChatCompletionResponseChoice (
270+ ChatCompletionChoice (
262271 index = idx ,
263- message = ChatMessage (role = "assistant" , content = "" ),
272+ message = AssistantMessage (role = "assistant" , content = "" ),
264273 finish_reason = None ,
265274 )
266275 )
@@ -287,9 +296,9 @@ def wrapper(wrapped, instance, args, kwargs):
287296
288297
289298def _llm_request_type_by_method (method_name ):
290- if method_name == "chat " or method_name == "chat_stream " :
299+ if method_name == "complete " or method_name == "stream " :
291300 return LLMRequestTypeValues .CHAT
292- elif method_name == "embeddings " :
301+ elif method_name == "create " :
293302 return LLMRequestTypeValues .EMBEDDING
294303 else :
295304 return LLMRequestTypeValues .UNKNOWN
@@ -301,7 +310,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
301310 if method_wrapped == "mistralai.chat" :
302311 messages = args [0 ] if len (args ) > 0 else kwargs .get ("messages" , [])
303312 for message in messages :
304- if isinstance (message , ChatMessage ):
313+ if isinstance (message , ( UserMessage , AssistantMessage , SystemMessage ) ):
305314 role = message .role
306315 content = message .content
307316 elif isinstance (message , dict ):
@@ -313,7 +322,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
313322
314323 # Handle embedding events
315324 elif method_wrapped == "mistralai.embeddings" :
316- embedding_input = args [0 ] if len (args ) > 0 else kwargs .get ("input" , [])
325+ embedding_input = args [0 ] if len (args ) > 0 else ( kwargs .get ("input" ) or kwargs . get ( "inputs" , []) )
317326 if isinstance (embedding_input , str ):
318327 emit_event (MessageEvent (content = embedding_input , role = "user" ), event_logger )
319328 elif isinstance (embedding_input , list ):
@@ -452,7 +461,7 @@ async def _awrap(
452461 _handle_input (span , event_logger , args , kwargs , to_wrap )
453462
454463 if to_wrap .get ("streaming" ):
455- response = wrapped (* args , ** kwargs )
464+ response = await wrapped (* args , ** kwargs )
456465 else :
457466 response = await wrapped (* args , ** kwargs )
458467
@@ -495,21 +504,23 @@ def _instrument(self, **kwargs):
495504
496505 for wrapped_method in WRAPPED_METHODS :
497506 wrap_method = wrapped_method .get ("method" )
507+ module_name = wrapped_method .get ("module" )
508+ # Wrap sync methods on the class
498509 wrap_function_wrapper (
499- "mistralai.client " ,
500- f"MistralClient .{ wrap_method } " ,
510+ f "mistralai.{ module_name } " ,
511+ f"{ module_name . capitalize () } .{ wrap_method } " ,
501512 _wrap (tracer , event_logger , wrapped_method ),
502513 )
514+ # Wrap async methods on the class
503515 wrap_function_wrapper (
504- "mistralai.async_client " ,
505- f"MistralAsyncClient. { wrap_method } " ,
516+ f "mistralai.{ module_name } " ,
517+ f"{ module_name . capitalize () } . { wrap_method } _async " ,
506518 _awrap (tracer , event_logger , wrapped_method ),
507519 )
508520
509521 def _uninstrument (self , ** kwargs ):
510522 for wrapped_method in WRAPPED_METHODS :
511- unwrap ("mistralai.client.MistralClient" , wrapped_method .get ("method" ))
512- unwrap (
513- "mistralai.async_client.MistralAsyncClient" ,
514- wrapped_method .get ("method" ),
515- )
523+ wrap_method = wrapped_method .get ("method" )
524+ module_name = wrapped_method .get ("module" )
525+ unwrap (f"mistralai.{ module_name } .{ module_name .capitalize ()} " , wrap_method )
526+ unwrap (f"mistralai.{ module_name } .{ module_name .capitalize ()} " , f"{ wrap_method } _async" )
0 commit comments