Skip to content

Commit 6d8d0a2

Browse files
authored
Add think chunk (#21333)
Signed-off-by: Julien Denize <[email protected]>
1 parent 11ef7a6 commit 6d8d0a2

File tree

11 files changed: +682 additions, -13 deletions

requirements/common.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
3333
msgspec
3434
gguf >= 0.13.0
3535
importlib_metadata; python_version < '3.10'
36-
mistral_common[opencv] >= 1.8.0
36+
mistral_common[image,audio] >= 1.8.2
3737
opencv-python-headless >= 4.11.0 # required for video IO
3838
pyyaml
3939
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

requirements/nightly_torch_test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jiwer # required for audio tests
2323
timm # required for internvl test
2424
transformers_stream_generator # required for qwen-vl test
2525
matplotlib # required for qwen-vl test
26-
mistral_common[opencv] >= 1.8.0 # required for voxtral test
26+
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
2727
num2words # required for smolvlm test
2828
opencv-python-headless >= 4.11.0 # required for video test
2929
datamodel_code_generator # required for minicpm3 test

requirements/test.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ torchvision==0.22.1
2828
transformers_stream_generator # required for qwen-vl test
2929
mamba_ssm # required for plamo2 test
3030
matplotlib # required for qwen-vl test
31-
mistral_common[opencv] >= 1.8.0 # required for voxtral test
31+
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
3232
num2words # required for smolvlm test
3333
open_clip_torch==2.32.0 # Required for nemotron_vl test
3434
opencv-python-headless >= 4.11.0 # required for video test

requirements/test.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ mbstrdecoder==1.1.3
447447
# typepy
448448
mdurl==0.1.2
449449
# via markdown-it-py
450-
mistral-common==1.8.0
450+
mistral-common==1.8.2
451451
# via -r requirements/test.in
452452
mlflow==2.22.0
453453
# via terratorch
@@ -999,8 +999,11 @@ soundfile==0.12.1
999999
# via
10001000
# -r requirements/test.in
10011001
# librosa
1002+
# mistral-common
10021003
soxr==0.5.0.post1
1003-
# via librosa
1004+
# via
1005+
# librosa
1006+
# mistral-common
10041007
sqlalchemy==2.0.41
10051008
# via
10061009
# alembic

tests/entrypoints/test_chat_utils.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from typing import Literal, Optional
77

88
import pytest
9+
from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
10+
SpecialTokens)
11+
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
12+
Tekkenizer)
913

1014
from vllm.assets.audio import AudioAsset
1115
from vllm.assets.image import ImageAsset
@@ -21,6 +25,7 @@
2125
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
2226
encode_video_base64)
2327
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
28+
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
2429

2530
from ..models.registry import HF_EXAMPLE_MODELS
2631
from ..utils import VLLM_PATH
@@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
13741379
)
13751380

13761381
assert resolved_format == expected_format
1382+
1383+
1384+
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
                                                    mistral_tokenizer):
    """Parsing a conversation with "thinking" chunks flattens each of them
    into a plain text chunk (and a bare-string user message into a
    single-element text chunk list)."""
    system_chunks = [
        {
            "type": "text",
            "text": "You are a helpful assistant."
        },
        {
            "type": "thinking",
            "closed": True,
            "thinking": "Only return the answer when you are confident.",
        },
    ]
    assistant_chunks = [
        {
            "type": "text",
            "text": "Let me think about it."
        },
        {
            "type": "thinking",
            "closed": True,
            "thinking": "2+2 = 4"
        },
        {
            "type": "text",
            "text": "The answer is 4."
        },
    ]
    messages = [
        {
            "role": "system",
            "content": system_chunks
        },
        {
            "role": "user",
            "content": "What is 2+2?"
        },
        {
            "role": "assistant",
            "content": assistant_chunks
        },
    ]

    conversation_with_thinking, _ = parse_chat_messages(
        messages,
        mistral_model_config,
        mistral_tokenizer,
        content_format="openai",
    )

    # Every "thinking" chunk is expected to come back as a "text" chunk
    # carrying the thinking string.
    expected_conversation = [
        {
            "role":
            "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant."
                },
                {
                    "type": "text",
                    "text": "Only return the answer when you are confident.",
                },
            ],
        },
        {
            "role": "user",
            "content": [{
                "type": "text",
                "text": "What is 2+2?"
            }],
        },
        {
            "role":
            "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "Let me think about it."
                },
                {
                    "type": "text",
                    "text": "2+2 = 4"
                },
                {
                    "type": "text",
                    "text": "The answer is 4."
                },
            ],
        },
    ]

    assert conversation_with_thinking == expected_conversation
1463+
1464+
1465+
def test_apply_mistral_chat_template_thinking_chunk():
    """Rendering a conversation with "thinking" chunks through the Mistral
    chat template wraps the thinking text in [THINK]...[/THINK] special
    tokens."""
    # Moved import here to avoid yapf and isort conflicts
    from vllm.entrypoints.chat_utils import apply_mistral_chat_template

    system_chunks = [
        {
            "type": "text",
            "text": "You are a helpful assistant."
        },
        {
            "type": "thinking",
            "closed": True,
            "thinking": "Only return the answer when you are confident.",
        },
    ]
    assistant_chunks = [
        {
            "type": "text",
            "text": "Let me think about it."
        },
        {
            "type": "thinking",
            "closed": True,
            "thinking": "2+2 = 4"
        },
        {
            "type": "text",
            "text": "The answer is 4."
        },
    ]
    messages = [
        {
            "role": "system",
            "content": system_chunks
        },
        {
            "role": "user",
            "content": "What is 2+2?"
        },
        {
            "role": "assistant",
            "content": assistant_chunks
        },
        {
            "role": "user",
            "content": "Thanks, what is 3+3?"
        },
    ]

    # TODO(Julien): upon model release change to a tokenizer already configured.
    # =================================================================
    mistral_tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Devstral-Small-2507")
    tekkenizer = mistral_tokenizer.tokenizer
    assert isinstance(tekkenizer, Tekkenizer)
    # Register [THINK]/[/THINK] as control tokens at ranks 35 and 36.
    begin_think, end_think = 35, 36
    tekkenizer._all_special_tokens[begin_think] = SpecialTokenInfo(
        rank=begin_think,
        is_control=True,
        token_str=SpecialTokens.begin_think.value)
    tekkenizer._all_special_tokens[end_think] = SpecialTokenInfo(
        rank=end_think,
        is_control=True,
        token_str=SpecialTokens.end_think.value)
    # Drop whatever token strings previously mapped to these ranks, then
    # point the think token strings at them.
    tekkenizer._special_tokens_reverse_vocab = {
        token: rank
        for token, rank in tekkenizer._special_tokens_reverse_vocab.items()
        if rank not in {begin_think, end_think}
    }
    tekkenizer._special_tokens_reverse_vocab[
        SpecialTokens.begin_think.value] = begin_think
    tekkenizer._special_tokens_reverse_vocab[
        SpecialTokens.end_think.value] = end_think
    mistral_tokenizer.instruct.BEGIN_THINK = begin_think
    mistral_tokenizer.instruct.END_THINK = end_think
    # =================================================================

    token_ids = apply_mistral_chat_template(mistral_tokenizer,
                                            messages,
                                            chat_template=None,
                                            tools=None)

    decoded = mistral_tokenizer.mistral.decode(
        token_ids, special_token_policy=SpecialTokenPolicy.KEEP)

    expected_tokens = (
        r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
        r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
        r"[INST]What is 2+2?[/INST]"
        r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
        r"[INST]Thanks, what is 3+3?[/INST]")

    assert decoded == expected_tokens

0 commit comments

Comments
 (0)