
Commit 57430fc

juliendenize, gemini-code-assist[bot], and mgoin authored
Default model load/config/tokenizer to mistral format if relevant files exist (#28659)
Signed-off-by: Julien Denize <[email protected]>
Signed-off-by: Julien Denize <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <[email protected]>
1 parent c68c7b4 commit 57430fc


15 files changed: +230 −34 lines changed


docs/features/tool_calling.md

Lines changed: 18 additions & 5 deletions
@@ -142,7 +142,7 @@ Flags: `--tool-call-parser hermes`
 Supported models:
 
 * `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
-* Additional mistral function-calling models are compatible as well.
+* Additional Mistral function-calling models are compatible as well.
 
 Known issues:
 
@@ -158,12 +158,25 @@ Known issues:
 
 Recommended flags:
 
-1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
+1. To use the official Mistral AI's format:
 
-   `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
+   `--tool-call-parser mistral`
 
-2. To use the default Transformers tokenization backend:
-   `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+2. To use the Transformers format when available:
+
+   `--tokenizer_mode hf --config_format hf --load_format hf --tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+
+!!! note
+    Models officially released by Mistral AI have two possible formats:
+
+    1. The official format that is used by default with `auto` or `mistral` arguments:
+
+       `--tokenizer_mode mistral --config_format mistral --load_format mistral`
+
+       This format uses [mistral-common](https://github.com/mistralai/mistral-common), the Mistral AI's tokenizer backend.
+
+    2. The Transformers format, when available, that is used with `hf` arguments:
+
+       `--tokenizer_mode hf --config_format hf --load_format hf --chat-template examples/tool_chat_template_mistral_parallel.jinja`
 
 ### Llama Models (`llama3_json`)
 
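The two recommended setups documented above can also be selected from the offline entrypoint. A minimal sketch, assuming vLLM is installed and the checkpoint is accessible; the keyword arguments simply mirror the CLI flags shown in the docs diff:

# Minimal sketch: pick one of the two formats (not both in the same process).
from vllm import LLM

# 1. Official Mistral AI format, backed by mistral-common. After this commit,
#    "auto" resolves to these values when the Mistral-format files exist.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)

# 2. Transformers format, forcing the Hugging Face config, weights and tokenizer:
# llm = LLM(
#     model="mistralai/Mistral-7B-Instruct-v0.3",
#     tokenizer_mode="hf",
#     config_format="hf",
#     load_format="hf",
# )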
tests/models/language/generation/test_mistral.py

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ def test_mistral_format(
     with vllm_runner(
         model,
         dtype=dtype,
-        tokenizer_mode="auto",
+        tokenizer_mode="hf",
         load_format="safetensors",
         config_format="hf",
     ) as hf_format_model:

tests/models/multimodal/test_mapping.py

Lines changed: 13 additions & 1 deletion
@@ -50,12 +50,24 @@ def test_hf_model_weights_mapper(model_arch: str):
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
+    is_mistral_model = model_arch in [
+        "Mistral3ForConditionalGeneration",
+        "PixtralForConditionalGeneration",
+        "VoxtralForConditionalGeneration",
+    ]
+
+    if not is_mistral_model or model_info.tokenizer_mode == "mistral":
+        tokenizer_mode = model_info.tokenizer_mode
+    else:
+        tokenizer_mode = "hf"
+
     model_id = model_info.default
 
     model_config = ModelConfig(
         model_id,
         tokenizer=model_info.tokenizer or model_id,
-        tokenizer_mode=model_info.tokenizer_mode,
+        tokenizer_mode=tokenizer_mode,
+        config_format="hf",
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,

tests/models/quantization/test_bitsandbytes.py

Lines changed: 3 additions & 0 deletions
@@ -259,6 +259,9 @@ def validate_generated_texts(
         tensor_parallel_size=vllm_tp_size,
         enforce_eager=False,
         default_torch_num_threads=1,
+        tokenizer_mode="hf",
+        load_format="hf",
+        config_format="hf",
     ) as llm:
         vllm_outputs = llm.generate_greedy(prompts, max_tokens)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

tests/tool_use/utils.py

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,12 @@ def ensure_system_prompt(
         "arguments": [
             "--enforce-eager",
             "--no-enable-prefix-caching",
+            "--tokenizer_mode",
+            "hf",
+            "--load_format",
+            "hf",
+            "--config_format",
+            "hf",
             "--tool-call-parser",
             "mistral",
             "--chat-template",
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from vllm.transformers_utils.config import list_filtered_repo_files
+
+
+@pytest.mark.parametrize(
+    "allow_patterns,expected_relative_files",
+    [
+        (
+            ["*.json", "correct*.txt"],
+            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
+        ),
+    ],
+)
+def test_list_filtered_repo_files(
+    allow_patterns: list[str], expected_relative_files: list[str]
+):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # Prep folder and files
+        path_tmp_dir = Path(tmp_dir)
+        subfolder = path_tmp_dir / "subfolder"
+        subfolder.mkdir()
+        (path_tmp_dir / "json_file.json").touch()
+        (path_tmp_dir / "correct_2.txt").touch()
+        (path_tmp_dir / "uncorrect.txt").touch()
+        (path_tmp_dir / "uncorrect.jpeg").touch()
+        (subfolder / "correct.txt").touch()
+        (subfolder / "uncorrect_sub.txt").touch()
+
+        def _glob_path() -> list[str]:
+            return [
+                str(file.relative_to(path_tmp_dir))
+                for file in path_tmp_dir.glob("**/*")
+                if file.is_file()
+            ]
+
+        # Patch list_repo_files called by fn
+        with patch(
+            "vllm.transformers_utils.config.list_repo_files",
+            MagicMock(return_value=_glob_path()),
+        ) as mock_list_repo_files:
+            out_files = sorted(
+                list_filtered_repo_files(
+                    tmp_dir, allow_patterns, "revision", "model", "token"
+                )
+            )
+        assert out_files == sorted(expected_relative_files)
+        assert mock_list_repo_files.call_count == 1
+        assert mock_list_repo_files.call_args_list[0] == call(
+            repo_id=tmp_dir,
+            revision="revision",
+            repo_type="model",
+            token="token",
+        )

tests/transformers_utils/test_utils.py

Lines changed: 5 additions & 1 deletion
@@ -2,7 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
-from vllm.transformers_utils.utils import is_cloud_storage, is_gcs, is_s3
+from vllm.transformers_utils.utils import (
+    is_cloud_storage,
+    is_gcs,
+    is_s3,
+)
 
 
 def test_is_gcs():

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 11 additions & 3 deletions
@@ -46,11 +46,15 @@
 
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    # FIXME: Since "auto" will use Mistral tokenizer and these backends do not support
+    # it, we skip these tests for now.
+    # ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    # ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", None),
     pytest.param(
         "mistralai/Ministral-8B-Instruct-2410",
         "lm-format-enforcer",
-        "auto",
+        "hf",
         None,
         marks=pytest.mark.skip(
             reason=(
@@ -80,7 +84,7 @@
     # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
     # ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", NGRAM_SPEC_CONFIG),
     ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG),
 ]
@@ -151,6 +155,8 @@ def test_structured_output(
         ),
         seed=120,
         tokenizer_mode=tokenizer_mode,
+        load_format="auto" if not model_name.startswith("mistralai/") else "hf",
+        config_format="auto" if not model_name.startswith("mistralai/") else "hf",
         speculative_config=speculative_config,
     )
 
@@ -720,6 +726,8 @@ def test_structured_output_auto_mode(
         max_model_len=1024,
         structured_outputs_config=dict(backend="auto"),
         tokenizer_mode=tokenizer_mode,
+        load_format="auto",
+        config_format="auto",
     )
 
 

vllm/config/model.py

Lines changed: 5 additions & 4 deletions
@@ -81,7 +81,7 @@
     "transcription",
     "draft",
 ]
-TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
+TokenizerMode = Literal["auto", "hf", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal[
     "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
@@ -130,7 +130,8 @@ class ModelConfig:
     name or path will be used."""
     tokenizer_mode: TokenizerMode = "auto"
     """Tokenizer mode:\n
-    - "auto" will use the fast tokenizer if available.\n
+    - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
+    - "hf" will use the fast tokenizer if available.\n
     - "slow" will always use the slow tokenizer.\n
     - "mistral" will always use the tokenizer from `mistral_common`.\n
     - "custom" will use --tokenizer to select the preregistered tokenizer."""
@@ -241,8 +242,8 @@
     first one."""
     config_format: str | ConfigFormat = "auto"
     """The format of the model config to load:\n
-    - "auto" will try to load the config in hf format if available else it
-    will try to load in mistral format.\n
+    - "auto" will try to load the config in hf format if available after trying
+    to load in mistral format.\n
    - "hf" will load the config in hf format.\n
    - "mistral" will load the config in mistral format."""
     hf_token: bool | str | None = None
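The docstring changes above describe the new "auto" behaviour: the Mistral format wins when the relevant files exist, otherwise the Hugging Face format is used. A minimal sketch of that resolution follows; the helper name and the single marker file checked are assumptions for illustration, not vLLM's actual detection code:

# Hypothetical sketch: resolve_default_format and the params.json check are
# illustrative assumptions, not the helpers this commit adds.
def resolve_default_format(repo_files: list[str]) -> str:
    # Mistral-format repos ship files such as params.json alongside
    # consolidated weights; the exact file list vLLM checks may be broader.
    if "params.json" in repo_files:
        return "mistral"
    return "hf"

# resolve_default_format(["params.json", "consolidated.safetensors"]) -> "mistral"
# resolve_default_format(["config.json", "model.safetensors"])        -> "hf"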

vllm/model_executor/model_loader/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
 # if a new load format is added here
 LoadFormats = Literal[
     "auto",
+    "hf",
     "bitsandbytes",
     "dummy",
     "fastsafetensors",
@@ -45,6 +46,7 @@
 ]
 _LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
     "auto": DefaultModelLoader,
+    "hf": DefaultModelLoader,
     "bitsandbytes": BitsAndBytesModelLoader,
     "dummy": DummyModelLoader,
     "fastsafetensors": DefaultModelLoader,

0 commit comments
