@@ -43,13 +43,10 @@
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.llms.mistralai.utils import (
-    is_mistralai_function_calling_model,
-    is_mistralai_code_model,
-    mistralai_modelname_to_contextsize,
-    MISTRAL_AI_REASONING_MODELS,
    THINKING_REGEX,
    THINKING_START_REGEX,
)
+from llama_index.llms.mistralai.helper import MistralHelper

from mistralai import Mistral
from mistralai.models import ToolCall
@@ -218,6 +215,7 @@ def __init__(
callback_manager = callback_manager or CallbackManager([])

api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")
+        self.helper = MistralHelper(api_key=api_key)

if not api_key:
raise ValueError(
@@ -260,12 +258,12 @@ def class_name(cls) -> str:
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
-            context_window=mistralai_modelname_to_contextsize(self.model),
+            context_window=self.helper.modelname_to_contextsize(self.model),
num_output=self.max_tokens,
is_chat_model=True,
model_name=self.model,
random_seed=self.random_seed,
-            is_function_calling_model=is_mistralai_function_calling_model(self.model),
+            is_function_calling_model=self.helper.is_function_calling_model(self.model),
)

@property
@@ -310,7 +308,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
response = self._client.chat.complete(messages=messages, **all_kwargs)

additional_kwargs = {}
-        if self.model in MISTRAL_AI_REASONING_MODELS:
+        if self.model in self.helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
)
@@ -371,7 +369,7 @@ def gen() -> ChatResponseGen:
content += content_delta

# decide whether to include thinking in deltas/responses
-                if self.model in MISTRAL_AI_REASONING_MODELS:
+                if self.model in self.helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(content)

if thinking_txt:
@@ -415,7 +413,7 @@ async def achat(
)

additional_kwargs = {}
-        if self.model in MISTRAL_AI_REASONING_MODELS:
+        if self.model in self.helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
)
@@ -475,7 +473,7 @@ async def gen() -> ChatResponseAsyncGen:
content += content_delta

# decide whether to include thinking in deltas/responses
-                if self.model in MISTRAL_AI_REASONING_MODELS:
+                if self.model in self.helper.get_reasoning_models():
thinking_txt, response_txt = self._separate_thinking(content)
if thinking_txt:
additional_kwargs["thinking"] = thinking_txt
@@ -583,7 +581,7 @@ def get_tool_calls_from_response(
def fill_in_middle(
self, prompt: str, suffix: str, stop: Optional[List[str]] = None
) -> CompletionResponse:
-        if not is_mistralai_code_model(self.model):
+        if not self.helper.is_code_model(self.model):
raise ValueError(
"Please provide code model from MistralAI. Currently supported code model is 'codestral-latest'."
)
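For context, a minimal sketch (not part of the diff) of what the reasoning-model branch above produces. It assumes MISTRAL_API_KEY is set; the model name is one of the reasoning models listed elsewhere in this PR, and availability depends on the live model listing:

from llama_index.core.llms import ChatMessage
from llama_index.llms.mistralai import MistralAI

llm = MistralAI(model="magistral-medium-latest")
response = llm.chat([ChatMessage(role="user", content="Why is the sky blue?")])

# For reasoning models, the <think> block is split out of the visible text
# and surfaced under additional_kwargs["thinking"].
print(response.message.additional_kwargs.get("thinking"))
print(response.message.content)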
llama_index/llms/mistralai/helper.py (new file)
@@ -0,0 +1,152 @@
import requests
from typing import List, Dict, Any


def extract_model_name(model: Dict[str, Any]) -> str:
    """Return the id of a model entry from the /v1/models listing."""
    return model["id"]


class MistralHelper:
def __init__(self, api_key: str) -> None:
"""
Initialize MistralHelper with API key.

Args:
api_key: API key for MistralAI.

"""
self.api_key = api_key
self.refresh_models()

def refresh_models(self) -> None:
"""Refresh the list of available models from MistralAI."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        url = "https://api.mistral.ai/v1/models"

        # Avoid hanging indefinitely if the endpoint is unreachable.
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        model_data: List[Dict[str, Any]] = response.json()["data"]

        # Map each advertised model id to its maximum context window.
        self.mistralai_models = {
            model["id"]: model["max_context_length"] for model in model_data
        }
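        # Models whose reported capabilities include function calling.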
self.function_calling_models = list(
map(
extract_model_name,
filter(
lambda m: m.get("capabilities", {}).get("function_calling"),
model_data,
),
)
)
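        # Heuristic: chat-capable models that mention "coding" in their description.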
self.coding_models = list(
map(
extract_model_name,
filter(
lambda m: m.get("capabilities", {}).get("completion_chat")
and "coding" in m.get("description", "").lower(),
model_data,
),
)
)
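        # Heuristic: chat-capable models that mention "reasoning" in their description.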
self.reasoning_models = list(
map(
extract_model_name,
filter(
lambda m: m.get("capabilities", {}).get("completion_chat")
and "reasoning" in m.get("description", "").lower(),
model_data,
),
)
)

def get_mistralai_models(self) -> Dict[str, int]:
"""
Get the dictionary of available MistralAI models and their context sizes.

Returns:
A dictionary mapping model names to their max_context_length.

"""
return self.mistralai_models

def get_function_calling_models(self) -> List[str]:
"""
Get the list of available MistralAI models that support function calling.

Returns:
A list of model names that support function calling.

"""
return self.function_calling_models

def get_coding_models(self) -> List[str]:
"""
Get the list of available MistralAI models that are designed for coding tasks.

Returns:
A list of model names that are coding models.

"""
return self.coding_models

def get_reasoning_models(self) -> List[str]:
"""
Get the list of available MistralAI models that are designed for reasoning tasks.

Returns:
A list of model names that are reasoning models.

"""
return self.reasoning_models

def modelname_to_contextsize(self, modelname: str) -> int:
"""
Get the context size for a given MistralAI model.

Args:
modelname: The name of the MistralAI model

Returns:
The context size (max_context_length) for the model

Raises:
ValueError: If the model is not found in the available models

"""
if modelname.startswith("ft:"):
modelname = modelname.split(":")[1]

if modelname not in self.mistralai_models:
raise ValueError(
f"Unknown model: {modelname}. Please provide a valid MistralAI model name."
"Known models are: " + ", ".join(self.mistralai_models.keys())
)

return self.mistralai_models[modelname]

def is_function_calling_model(self, modelname: str) -> bool:
"""
Check if a model supports function calling.

Args:
modelname: The name of the MistralAI model

Returns:
True if the model supports function calling, False otherwise

"""
return modelname in self.function_calling_models

def is_code_model(self, modelname: str) -> bool:
"""
Check if a model is specifically designed for coding tasks.

Args:
modelname: The name of the MistralAI model

Returns:
True if the model is a coding model, False otherwise

"""
return modelname in self.coding_models
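A minimal usage sketch of the new helper (not part of the diff; assumes MISTRAL_API_KEY is set and https://api.mistral.ai/v1/models is reachable):

import os

from llama_index.llms.mistralai.helper import MistralHelper

helper = MistralHelper(api_key=os.environ["MISTRAL_API_KEY"])

# Context sizes now come from the live model listing instead of a static table.
print(helper.modelname_to_contextsize("mistral-large-latest"))

# Capability checks are derived from each model's reported metadata.
print(helper.is_function_calling_model("mistral-large-latest"))
print(helper.is_code_model("codestral-latest"))

# Re-query the endpoint to pick up newly released models.
helper.refresh_models()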
llama_index/llms/mistralai/utils.py
@@ -1,76 +1,6 @@
import re
from typing import Dict

MISTRALAI_MODELS: Dict[str, int] = {
"mistral-tiny": 32000,
"mistral-small": 32000,
"mistral-medium": 32000,
"mistral-large": 131000,
"mistral-saba-latest": 32000,
"open-mixtral-8x7b": 32000,
"open-mistral-7b": 32000,
"open-mixtral-8x22b": 64000,
"mistral-small-latest": 32000,
"mistral-medium-latest": 32000,
"mistral-large-latest": 32000,
"codestral-latest": 256000,
"open-mistral-nemo-latest": 131000,
"ministral-8b-latest": 131000,
"ministral-3b-latest": 131000,
"pixtral-large-latest": 131000,
"pixtral-12b-2409": 131000,
"magistral-medium-2506": 40000,
"magistral-small-2506": 40000,
"magistral-medium-latest": 40000,
"magistral-small-latest": 40000,
}

MISTRALAI_FUNCTION_CALLING_MODELS = (
"mistral-large-latest",
"open-mixtral-8x22b",
"ministral-8b-latest",
"ministral-3b-latest",
"mistral-small-latest",
"codestral-latest",
"open-mistral-nemo-latest",
"pixtral-large-latest",
"pixtral-12b-2409",
"magistral-medium-2506",
"magistral-small-2506",
"magistral-medium-latest",
"magistral-small-latest",
)

MISTRAL_AI_REASONING_MODELS = (
"magistral-medium-2506",
"magistral-small-2506",
"magistral-medium-latest",
"magistral-small-latest",
)

MISTRALAI_CODE_MODELS = "codestral-latest"

THINKING_REGEX = re.compile(r"^<think>\n(.*?)\n</think>\n")
THINKING_START_REGEX = re.compile(r"^<think>\n")


def mistralai_modelname_to_contextsize(modelname: str) -> int:
# handling finetuned models
if modelname.startswith("ft:"):
modelname = modelname.split(":")[1]

if modelname not in MISTRALAI_MODELS:
raise ValueError(
f"Unknown model: {modelname}. Please provide a valid MistralAI model name."
"Known models are: " + ", ".join(MISTRALAI_MODELS.keys())
)

return MISTRALAI_MODELS[modelname]


def is_mistralai_function_calling_model(modelname: str) -> bool:
return modelname in MISTRALAI_FUNCTION_CALLING_MODELS


def is_mistralai_code_model(modelname: str) -> bool:
return modelname in MISTRALAI_CODE_MODELS
@@ -186,3 +186,26 @@ def test_to_mistral_chunks(tmp_path: Path, image_url: str) -> None:
assert isinstance(chunks_with_path[1], ImageURLChunk)
assert isinstance(chunks_with_path[1].image_url, str)
assert chunks_with_path[1].image_url == f"data:image/png;base64,{expected_b64}"


@pytest.mark.skipif(
os.environ.get("MISTRAL_API_KEY") is None, reason="MISTRAL_API_KEY not set"
)
def test_helper_sanity():
    """Live sanity check of MistralHelper against the MistralAI models endpoint."""
    from llama_index.llms.mistralai.helper import MistralHelper

    helper = MistralHelper(api_key=os.environ["MISTRAL_API_KEY"])
assert isinstance(helper.get_mistralai_models(), dict)
assert isinstance(helper.get_function_calling_models(), list)
assert isinstance(helper.get_coding_models(), list)
assert isinstance(helper.get_reasoning_models(), list)
models = helper.get_mistralai_models()
for model in models:
        # Context sizes must round-trip through the lookup helper.
        assert helper.modelname_to_contextsize(model) == models[model]
        # Capability predicates must agree with the corresponding lists in both directions.
        assert helper.is_function_calling_model(model) == (
            model in helper.get_function_calling_models()
        )
        assert helper.is_code_model(model) == (model in helper.get_coding_models())