diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py index 903910dc2d..df84732bcd 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py @@ -43,13 +43,10 @@ from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.llms.mistralai.utils import ( - is_mistralai_function_calling_model, - is_mistralai_code_model, - mistralai_modelname_to_contextsize, - MISTRAL_AI_REASONING_MODELS, THINKING_REGEX, THINKING_START_REGEX, ) +from llama_index.llms.mistralai.helper import MistralHelper from mistralai import Mistral from mistralai.models import ToolCall @@ -218,6 +215,7 @@ def __init__( callback_manager = callback_manager or CallbackManager([]) api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "") + self.helper = MistralHelper(api_key=api_key) if not api_key: raise ValueError( @@ -260,12 +258,12 @@ def class_name(cls) -> str: @property def metadata(self) -> LLMMetadata: return LLMMetadata( - context_window=mistralai_modelname_to_contextsize(self.model), + context_window=self.helper.modelname_to_contextsize(self.model), num_output=self.max_tokens, is_chat_model=True, model_name=self.model, random_seed=self.random_seed, - is_function_calling_model=is_mistralai_function_calling_model(self.model), + is_function_calling_model=self.helper.is_function_calling_model(self.model), ) @property @@ -310,7 +308,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: response = self._client.chat.complete(messages=messages, **all_kwargs) additional_kwargs = {} - if self.model in MISTRAL_AI_REASONING_MODELS: + if self.model in self.helper.get_reasoning_models(): thinking_txt, response_txt = self._separate_thinking( response.choices[0].message.content ) @@ -371,7 +369,7 @@ def gen() -> ChatResponseGen: content += content_delta # decide whether to include thinking in deltas/responses - if self.model in MISTRAL_AI_REASONING_MODELS: + if self.model in self.helper.get_reasoning_models(): thinking_txt, response_txt = self._separate_thinking(content) if thinking_txt: @@ -415,7 +413,7 @@ async def achat( ) additional_kwargs = {} - if self.model in MISTRAL_AI_REASONING_MODELS: + if self.model in self.helper.get_reasoning_models(): thinking_txt, response_txt = self._separate_thinking( response.choices[0].message.content ) @@ -475,7 +473,7 @@ async def gen() -> ChatResponseAsyncGen: content += content_delta # decide whether to include thinking in deltas/responses - if self.model in MISTRAL_AI_REASONING_MODELS: + if self.model in self.helper.get_reasoning_models(): thinking_txt, response_txt = self._separate_thinking(content) if thinking_txt: additional_kwargs["thinking"] = thinking_txt @@ -583,7 +581,7 @@ def get_tool_calls_from_response( def fill_in_middle( self, prompt: str, suffix: str, stop: Optional[List[str]] = None ) -> CompletionResponse: - if not is_mistralai_code_model(self.model): + if not self.helper.is_code_model(self.model): raise ValueError( "Please provide code model from MistralAI. Currently supported code model is 'codestral-latest'." ) diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py new file mode 100644 index 0000000000..d9bdc7e02f --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py @@ -0,0 +1,152 @@ +import requests +from typing import List, Dict, Any + + +def extract_model_name(model): + return model["id"] + + +class MistralHelper: + def __init__(self, api_key: str) -> None: + """ + Initialize MistralHelper with API key. + + Args: + api_key: API key for MistralAI. + + """ + self.api_key = api_key + self.refresh_models() + + def refresh_models(self) -> None: + """Refresh the list of available models from MistralAI.""" + headers = {"Authorization": f"Bearer {self.api_key}"} + url = "https://api.mistral.ai/v1/models" + + response = requests.get(url, headers=headers) + response.raise_for_status() + model_data: List[Dict[str, Any]] = response.json()["data"] + + self.mistralai_models = { + model["id"]: model["max_context_length"] for model in model_data + } + self.function_calling_models = list( + map( + extract_model_name, + filter( + lambda m: m.get("capabilities", {}).get("function_calling"), + model_data, + ), + ) + ) + self.coding_models = list( + map( + extract_model_name, + filter( + lambda m: m.get("capabilities", {}).get("completion_chat") + and "coding" in m.get("description", "").lower(), + model_data, + ), + ) + ) + self.reasoning_models = list( + map( + extract_model_name, + filter( + lambda m: m.get("capabilities", {}).get("completion_chat") + and "reasoning" in m.get("description", "").lower(), + model_data, + ), + ) + ) + + def get_mistralai_models(self) -> Dict[str, int]: + """ + Get the dictionary of available MistralAI models and their context sizes. + + Returns: + A dictionary mapping model names to their max_context_length. + + """ + return self.mistralai_models + + def get_function_calling_models(self) -> List[str]: + """ + Get the list of available MistralAI models that support function calling. + + Returns: + A list of model names that support function calling. + + """ + return self.function_calling_models + + def get_coding_models(self) -> List[str]: + """ + Get the list of available MistralAI models that are designed for coding tasks. + + Returns: + A list of model names that are coding models. + + """ + return self.coding_models + + def get_reasoning_models(self) -> List[str]: + """ + Get the list of available MistralAI models that are designed for reasoning tasks. + + Returns: + A list of model names that are reasoning models. + + """ + return self.reasoning_models + + def modelname_to_contextsize(self, modelname: str) -> int: + """ + Get the context size for a given MistralAI model. + + Args: + modelname: The name of the MistralAI model + + Returns: + The context size (max_context_length) for the model + + Raises: + ValueError: If the model is not found in the available models + + """ + if modelname.startswith("ft:"): + modelname = modelname.split(":")[1] + + if modelname not in self.mistralai_models: + raise ValueError( + f"Unknown model: {modelname}. Please provide a valid MistralAI model name." + "Known models are: " + ", ".join(self.mistralai_models.keys()) + ) + + return self.mistralai_models[modelname] + + def is_function_calling_model(self, modelname: str) -> bool: + """ + Check if a model supports function calling. + + Args: + modelname: The name of the MistralAI model + + Returns: + True if the model supports function calling, False otherwise + + """ + return modelname in self.function_calling_models + + def is_code_model(self, modelname: str) -> bool: + """ + Check if a model is specifically designed for coding tasks. + + Args: + modelname: The name of the MistralAI model + + Returns: + True if the model is a coding model, False otherwise + + """ + return modelname in self.coding_models diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py index 040e76bf14..99777a6be6 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py @@ -1,76 +1,6 @@ import re -from typing import Dict - -MISTRALAI_MODELS: Dict[str, int] = { - "mistral-tiny": 32000, - "mistral-small": 32000, - "mistral-medium": 32000, - "mistral-large": 131000, - "mistral-saba-latest": 32000, - "open-mixtral-8x7b": 32000, - "open-mistral-7b": 32000, - "open-mixtral-8x22b": 64000, - "mistral-small-latest": 32000, - "mistral-medium-latest": 32000, - "mistral-large-latest": 32000, - "codestral-latest": 256000, - "open-mistral-nemo-latest": 131000, - "ministral-8b-latest": 131000, - "ministral-3b-latest": 131000, - "pixtral-large-latest": 131000, - "pixtral-12b-2409": 131000, - "magistral-medium-2506": 40000, - "magistral-small-2506": 40000, - "magistral-medium-latest": 40000, - "magistral-small-latest": 40000, -} - -MISTRALAI_FUNCTION_CALLING_MODELS = ( - "mistral-large-latest", - "open-mixtral-8x22b", - "ministral-8b-latest", - "ministral-3b-latest", - "mistral-small-latest", - "codestral-latest", - "open-mistral-nemo-latest", - "pixtral-large-latest", - "pixtral-12b-2409", - "magistral-medium-2506", - "magistral-small-2506", - "magistral-medium-latest", - "magistral-small-latest", -) - -MISTRAL_AI_REASONING_MODELS = ( - "magistral-medium-2506", - "magistral-small-2506", - "magistral-medium-latest", - "magistral-small-latest", -) MISTRALAI_CODE_MODELS = "codestral-latest" THINKING_REGEX = re.compile(r"^\n(.*?)\n\n") THINKING_START_REGEX = re.compile(r"^\n") - - -def mistralai_modelname_to_contextsize(modelname: str) -> int: - # handling finetuned models - if modelname.startswith("ft:"): - modelname = modelname.split(":")[1] - - if modelname not in MISTRALAI_MODELS: - raise ValueError( - f"Unknown model: {modelname}. Please provide a valid MistralAI model name." - "Known models are: " + ", ".join(MISTRALAI_MODELS.keys()) - ) - - return MISTRALAI_MODELS[modelname] - - -def is_mistralai_function_calling_model(modelname: str) -> bool: - return modelname in MISTRALAI_FUNCTION_CALLING_MODELS - - -def is_mistralai_code_model(modelname: str) -> bool: - return modelname in MISTRALAI_CODE_MODELS diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py index c24d05cf38..db3e5e326c 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py @@ -186,3 +186,26 @@ def test_to_mistral_chunks(tmp_path: Path, image_url: str) -> None: assert isinstance(chunks_with_path[1], ImageURLChunk) assert isinstance(chunks_with_path[1].image_url, str) assert chunks_with_path[1].image_url == f"data:image/png;base64,{expected_b64}" + + +@pytest.mark.skipif( + os.environ.get("MISTRAL_API_KEY") is None, reason="MISTRAL_API_KEY not set" +) +def helper_sanity(): + from llama_index.llms.mistralai.helper import MistralHelper + + helper = MistralHelper(api_key=os.environ.get("MISTRAL_API_KEY")) + assert isinstance(helper.get_mistralai_models(), dict) + assert isinstance(helper.get_function_calling_models(), list) + assert isinstance(helper.get_coding_models(), list) + assert isinstance(helper.get_reasoning_models(), list) + models = helper.get_mistralai_models() + for model in models: + context_size = helper.modelname_to_contextsize(model) + assert context_size == models[model] + is_fc = helper.is_function_calling_model(model) + if model in helper.get_function_calling_models(): + assert is_fc + is_code = helper.is_code_model(model) + if model in helper.get_coding_models(): + assert is_code