diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 903910dc2d..df84732bcd 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -43,13 +43,10 @@
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.llms.mistralai.utils import (
- is_mistralai_function_calling_model,
- is_mistralai_code_model,
- mistralai_modelname_to_contextsize,
- MISTRAL_AI_REASONING_MODELS,
THINKING_REGEX,
THINKING_START_REGEX,
)
+from llama_index.llms.mistralai.helper import MistralHelper
from mistralai import Mistral
from mistralai.models import ToolCall
@@ -218,6 +215,9 @@ def __init__(
callback_manager = callback_manager or CallbackManager([])
api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")
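+ # NOTE: constructing MistralHelper makes a live request to the MistralAI
+ # models endpoint to populate context sizes and capability lists.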
+ self._helper = MistralHelper(api_key=api_key)
if not api_key:
raise ValueError(
@@ -260,12 +258,12 @@ def class_name(cls) -> str:
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
- context_window=mistralai_modelname_to_contextsize(self.model),
+ context_window=self._helper.modelname_to_contextsize(self.model),
num_output=self.max_tokens,
is_chat_model=True,
model_name=self.model,
random_seed=self.random_seed,
- is_function_calling_model=is_mistralai_function_calling_model(self.model),
+ is_function_calling_model=self._helper.is_function_calling_model(self.model),
)
@property
@@ -310,7 +308,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
response = self._client.chat.complete(messages=messages, **all_kwargs)
additional_kwargs = {}
- if self.model in MISTRAL_AI_REASONING_MODELS:
+ if self.model in self._helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
)
@@ -371,7 +369,7 @@ def gen() -> ChatResponseGen:
content += content_delta
# decide whether to include thinking in deltas/responses
- if self.model in MISTRAL_AI_REASONING_MODELS:
+ if self.model in self._helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(content)
if thinking_txt:
@@ -415,7 +413,7 @@ async def achat(
)
additional_kwargs = {}
- if self.model in MISTRAL_AI_REASONING_MODELS:
+ if self.model in self._helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
)
@@ -475,7 +473,7 @@ async def gen() -> ChatResponseAsyncGen:
content += content_delta
# decide whether to include thinking in deltas/responses
- if self.model in MISTRAL_AI_REASONING_MODELS:
+ if self.model in self._helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(content)
if thinking_txt:
additional_kwargs["thinking"] = thinking_txt
@@ -583,7 +581,7 @@ def get_tool_calls_from_response(
def fill_in_middle(
self, prompt: str, suffix: str, stop: Optional[List[str]] = None
) -> CompletionResponse:
- if not is_mistralai_code_model(self.model):
+ if not self._helper.is_code_model(self.model):
raise ValueError(
"Please provide code model from MistralAI. Currently supported code model is 'codestral-latest'."
)
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py
new file mode 100644
index 0000000000..d9bdc7e02f
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/helper.py
@@ -0,0 +1,158 @@
+import requests
+from typing import List, Dict, Any
+
+
+def extract_model_name(model: Dict[str, Any]) -> str:
+ return model["id"]
+
+
+class MistralHelper:
+ def __init__(self, api_key: str) -> None:
+ """
+ Initialize MistralHelper with API key.
+
+ Args:
+ api_key: API key for MistralAI.
+
+ """
+ self.api_key = api_key
+ self.refresh_models()
+
+ def refresh_models(self) -> None:
+ """Refresh the list of available models from MistralAI."""
+ headers = {"Authorization": f"Bearer {self.api_key}"}
+ url = "https://api.mistral.ai/v1/models"
+
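+ # The models endpoint returns {"data": [...]}; the parsing below relies on
+ # each entry exposing "id", "max_context_length", and a "capabilities" map.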
+ response = requests.get(url, headers=headers, timeout=30)
+ response.raise_for_status()
+ model_data: List[Dict[str, Any]] = response.json()["data"]
+
+ self.mistralai_models = {
+ model["id"]: model["max_context_length"] for model in model_data
+ }
+ self.function_calling_models = list(
+ map(
+ extract_model_name,
+ filter(
+ lambda m: m.get("capabilities", {}).get("function_calling"),
+ model_data,
+ ),
+ )
+ )
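+ # The capabilities payload has no dedicated coding/reasoning flag, so the
+ # two groups below are inferred from the free-text model description.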
+ self.coding_models = list(
+ map(
+ extract_model_name,
+ filter(
+ lambda m: m.get("capabilities", {}).get("completion_chat")
+ and "coding" in m.get("description", "").lower(),
+ model_data,
+ ),
+ )
+ )
+ self.reasoning_models = list(
+ map(
+ extract_model_name,
+ filter(
+ lambda m: m.get("capabilities", {}).get("completion_chat")
+ and "reasoning" in m.get("description", "").lower(),
+ model_data,
+ ),
+ )
+ )
+
+ def get_mistralai_models(self) -> Dict[str, int]:
+ """
+ Get the dictionary of available MistralAI models and their context sizes.
+
+ Returns:
+ A dictionary mapping model names to their max_context_length.
+
+ """
+ return self.mistralai_models
+
+ def get_function_calling_models(self) -> List[str]:
+ """
+ Get the list of available MistralAI models that support function calling.
+
+ Returns:
+ A list of model names that support function calling.
+
+ """
+ return self.function_calling_models
+
+ def get_coding_models(self) -> List[str]:
+ """
+ Get the list of available MistralAI models that are designed for coding tasks.
+
+ Returns:
+ A list of model names that are coding models.
+
+ """
+ return self.coding_models
+
+ def get_reasoning_models(self) -> List[str]:
+ """
+ Get the list of available MistralAI models that are designed for reasoning tasks.
+
+ Returns:
+ A list of model names that are reasoning models.
+
+ """
+ return self.reasoning_models
+
+ def modelname_to_contextsize(self, modelname: str) -> int:
+ """
+ Get the context size for a given MistralAI model.
+
+ Args:
+ modelname: The name of the MistralAI model
+
+ Returns:
+ The context size (max_context_length) for the model
+
+ Raises:
+ ValueError: If the model is not found in the available models
+
+ """
+ if modelname.startswith("ft:"):
+ modelname = modelname.split(":")[1]
+
+ if modelname not in self.mistralai_models:
+ raise ValueError(
+ f"Unknown model: {modelname}. Please provide a valid MistralAI model name."
+ "Known models are: " + ", ".join(self.mistralai_models.keys())
+ )
+
+ return self.mistralai_models[modelname]
+
+ def is_function_calling_model(self, modelname: str) -> bool:
+ """
+ Check if a model supports function calling.
+
+ Args:
+ modelname: The name of the MistralAI model
+
+ Returns:
+ True if the model supports function calling, False otherwise
+
+ """
+ return modelname in self.function_calling_models
+
+ def is_code_model(self, modelname: str) -> bool:
+ """
+ Check if a model is specifically designed for coding tasks.
+
+ Args:
+ modelname: The name of the MistralAI model
+
+ Returns:
+ True if the model is a coding model, False otherwise
+
+ """
+ return modelname in self.coding_models
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
index 040e76bf14..99777a6be6 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
@@ -1,76 +1,6 @@
import re
-from typing import Dict
-
-MISTRALAI_MODELS: Dict[str, int] = {
- "mistral-tiny": 32000,
- "mistral-small": 32000,
- "mistral-medium": 32000,
- "mistral-large": 131000,
- "mistral-saba-latest": 32000,
- "open-mixtral-8x7b": 32000,
- "open-mistral-7b": 32000,
- "open-mixtral-8x22b": 64000,
- "mistral-small-latest": 32000,
- "mistral-medium-latest": 32000,
- "mistral-large-latest": 32000,
- "codestral-latest": 256000,
- "open-mistral-nemo-latest": 131000,
- "ministral-8b-latest": 131000,
- "ministral-3b-latest": 131000,
- "pixtral-large-latest": 131000,
- "pixtral-12b-2409": 131000,
- "magistral-medium-2506": 40000,
- "magistral-small-2506": 40000,
- "magistral-medium-latest": 40000,
- "magistral-small-latest": 40000,
-}
-
-MISTRALAI_FUNCTION_CALLING_MODELS = (
- "mistral-large-latest",
- "open-mixtral-8x22b",
- "ministral-8b-latest",
- "ministral-3b-latest",
- "mistral-small-latest",
- "codestral-latest",
- "open-mistral-nemo-latest",
- "pixtral-large-latest",
- "pixtral-12b-2409",
- "magistral-medium-2506",
- "magistral-small-2506",
- "magistral-medium-latest",
- "magistral-small-latest",
-)
-
-MISTRAL_AI_REASONING_MODELS = (
- "magistral-medium-2506",
- "magistral-small-2506",
- "magistral-medium-latest",
- "magistral-small-latest",
-)
MISTRALAI_CODE_MODELS = "codestral-latest"
THINKING_REGEX = re.compile(r"^<think>\n(.*?)\n</think>\n\n")
THINKING_START_REGEX = re.compile(r"^<think>\n")
-
-
-def mistralai_modelname_to_contextsize(modelname: str) -> int:
- # handling finetuned models
- if modelname.startswith("ft:"):
- modelname = modelname.split(":")[1]
-
- if modelname not in MISTRALAI_MODELS:
- raise ValueError(
- f"Unknown model: {modelname}. Please provide a valid MistralAI model name."
- "Known models are: " + ", ".join(MISTRALAI_MODELS.keys())
- )
-
- return MISTRALAI_MODELS[modelname]
-
-
-def is_mistralai_function_calling_model(modelname: str) -> bool:
- return modelname in MISTRALAI_FUNCTION_CALLING_MODELS
-
-
-def is_mistralai_code_model(modelname: str) -> bool:
- return modelname in MISTRALAI_CODE_MODELS
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
index c24d05cf38..db3e5e326c 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
@@ -186,3 +186,27 @@ def test_to_mistral_chunks(tmp_path: Path, image_url: str) -> None:
assert isinstance(chunks_with_path[1], ImageURLChunk)
assert isinstance(chunks_with_path[1].image_url, str)
assert chunks_with_path[1].image_url == f"data:image/png;base64,{expected_b64}"
+
+
+@pytest.mark.skipif(
+ os.environ.get("MISTRAL_API_KEY") is None, reason="MISTRAL_API_KEY not set"
+)
+def test_helper_sanity() -> None:
+ from llama_index.llms.mistralai.helper import MistralHelper
+
+ helper = MistralHelper(api_key=os.environ.get("MISTRAL_API_KEY"))
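+ # These assertions exercise the live MistralAI API; skipped without a key.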
+ assert isinstance(helper.get_mistralai_models(), dict)
+ assert isinstance(helper.get_function_calling_models(), list)
+ assert isinstance(helper.get_coding_models(), list)
+ assert isinstance(helper.get_reasoning_models(), list)
+ models = helper.get_mistralai_models()
+ for model in models:
+ context_size = helper.modelname_to_contextsize(model)
+ assert context_size == models[model]
+ assert helper.is_function_calling_model(model) == (
+ model in helper.get_function_calling_models()
+ )
+ assert helper.is_code_model(model) == (
+ model in helper.get_coding_models()
+ )