Skip to content

Commit f06847e

Browse files
committed
Add Mistral guidance
Signed-off-by: Julien Denize <[email protected]>
1 parent 1d367a7 commit f06847e

File tree

11 files changed

+1506
-77
lines changed

11 files changed

+1506
-77
lines changed

tests/tool_parsers/test_mistral_tool_parser.py

Lines changed: 1101 additions & 14 deletions
Large diffs are not rendered by default.

vllm/entrypoints/openai/chat_completion/serving.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
validate_request_params,
8282
)
8383
from vllm.tool_parsers import ToolParser
84-
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
84+
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall, MistralToolParser
8585
from vllm.tool_parsers.utils import partial_json_loads
8686
from vllm.utils.collection_utils import as_list
8787
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
@@ -142,6 +142,9 @@ def __init__(
142142
enable_auto_tools=enable_auto_tools,
143143
model_name=self.model_config.model,
144144
)
145+
if self.tool_parser == MistralToolParser and self.reasoning_parser is not None:
146+
self.tool_parser.reasoning = True
147+
145148
self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
146149

147150
self.enable_prompt_tokens_details = enable_prompt_tokens_details

vllm/entrypoints/openai/engine/serving.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@
110110
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
111111
from vllm.sampling_params import BeamSearchParams, SamplingParams
112112
from vllm.tokenizers import TokenizerLike
113+
from vllm.tokenizers.mistral import MistralTokenizer
113114
from vllm.tool_parsers import ToolParser
115+
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
114116
from vllm.tracing import (
115117
contains_trace_headers,
116118
extract_trace_headers,
@@ -1228,24 +1230,33 @@ def _parse_tool_calls_from_content(
12281230
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
12291231
content: str | None = None,
12301232
) -> tuple[list[FunctionCall] | None, str | None]:
1233+
is_mistral_flow = tool_parser_cls == MistralToolParser and isinstance(
1234+
tokenizer, MistralTokenizer
1235+
)
12311236
function_calls = list[FunctionCall]()
1232-
if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
1237+
if (
1238+
request.tool_choice
1239+
and isinstance(request.tool_choice, ToolChoiceFunction)
1240+
and not is_mistral_flow
1241+
):
12331242
assert content is not None
12341243
# Forced Function Call
12351244
function_calls.append(
12361245
FunctionCall(name=request.tool_choice.name, arguments=content)
12371246
)
12381247
content = None # Clear content since tool is called.
1239-
elif request.tool_choice and isinstance(
1240-
request.tool_choice, ChatCompletionNamedToolChoiceParam
1248+
elif (
1249+
request.tool_choice
1250+
and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
1251+
and not is_mistral_flow
12411252
):
12421253
assert content is not None
12431254
# Forced Function Call
12441255
function_calls.append(
12451256
FunctionCall(name=request.tool_choice.function.name, arguments=content)
12461257
)
12471258
content = None # Clear content since tool is called.
1248-
elif request.tool_choice == "required":
1259+
elif request.tool_choice == "required" and not is_mistral_flow:
12491260
assert content is not None
12501261
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
12511262
function_calls.extend(
@@ -1259,7 +1270,8 @@ def _parse_tool_calls_from_content(
12591270
)
12601271
content = None # Clear content since tool is called.
12611272
elif (
1262-
tool_parser_cls
1273+
is_mistral_flow
1274+
or tool_parser_cls
12631275
and enable_auto_tools
12641276
and (request.tool_choice == "auto" or request.tool_choice is None)
12651277
):
@@ -1270,6 +1282,7 @@ def _parse_tool_calls_from_content(
12701282

12711283
# Automatic Tool Call Parsing
12721284
try:
1285+
assert tool_parser_cls is not None
12731286
tool_parser = tool_parser_cls(tokenizer)
12741287
except RuntimeError as e:
12751288
logger.exception("Error in tool parser creation.")

vllm/sampling_params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class StructuredOutputsParams:
3737
regex: str | None = None
3838
choice: list[str] | None = None
3939
grammar: str | None = None
40+
lark: str | None = None
4041
json_object: bool | None = None
4142
# These are other options that can be set.
4243
disable_fallback: bool = False
@@ -58,6 +59,7 @@ def __post_init__(self):
5859
self.regex is not None,
5960
self.choice is not None,
6061
self.grammar is not None,
62+
self.lark is not None,
6163
self.json_object is not None,
6264
self.structural_tag is not None,
6365
]
@@ -84,6 +86,7 @@ def all_constraints_none(self) -> bool:
8486
"regex",
8587
"choice",
8688
"grammar",
89+
"lark",
8790
"json_object",
8891
"structural_tag",
8992
)
@@ -100,6 +103,7 @@ def all_non_structural_tag_constraints_none(self) -> bool:
100103
"regex",
101104
"choice",
102105
"grammar",
106+
"lark",
103107
"json_object",
104108
)
105109
)

vllm/tokenizers/mistral.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from pathlib import Path
44
from typing import TYPE_CHECKING, Any, cast, overload
55

6+
import llguidance as llg
7+
import regex as re
68
from mistral_common.protocol.instruct.request import (
79
ChatCompletionRequest as MistralChatCompletionRequest,
810
)
@@ -11,8 +13,15 @@
1113
from mistral_common.tokens.tokenizers.base import (
1214
SpecialTokenPolicy,
1315
SpecialTokens,
16+
Tokenizer,
17+
)
18+
from mistral_common.tokens.tokenizers.instruct import (
19+
InstructTokenizerBase,
20+
InstructTokenizerV13,
21+
)
22+
from mistral_common.tokens.tokenizers.mistral import (
23+
MistralTokenizer as MistralCommonTokenizer,
1424
)
15-
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
1625
from mistral_common.tokens.tokenizers.sentencepiece import (
1726
SentencePieceTokenizer,
1827
)
@@ -21,20 +30,21 @@
2130
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
2231
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
2332
from vllm.logger import init_logger
33+
from vllm.tokenizers.protocol import TokenizerLike
34+
35+
try:
36+
# Transformers v5
37+
from transformers.tokenization_mistral_common import MistralCommonBackend
38+
except ImportError:
39+
# Transformers v4
40+
from transformers.tokenization_mistral_common import (
41+
MistralCommonTokenizer as MistralCommonBackend,
42+
)
2443

25-
from .protocol import TokenizerLike
2644

2745
if TYPE_CHECKING:
2846
from transformers import BatchEncoding
2947

30-
try:
31-
# Transformers v5
32-
from transformers.tokenization_mistral_common import MistralCommonBackend
33-
except ImportError:
34-
# Transformers v4
35-
from transformers.tokenization_mistral_common import (
36-
MistralCommonTokenizer as MistralCommonBackend,
37-
)
3848

3949
logger = init_logger(__name__)
4050

@@ -217,15 +227,6 @@ def from_pretrained(
217227
download_dir: str | None = None,
218228
**kwargs,
219229
) -> "MistralTokenizer":
220-
try:
221-
# Transformers v5
222-
from transformers.tokenization_mistral_common import MistralCommonBackend
223-
except ImportError:
224-
# Transformers v4
225-
from transformers.tokenization_mistral_common import (
226-
MistralCommonTokenizer as MistralCommonBackend,
227-
)
228-
229230
tokenizer = MistralCommonBackend.from_pretrained(
230231
path_or_repo_id,
231232
*args,
@@ -240,10 +241,10 @@ def from_pretrained(
240241
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
241242
super().__init__()
242243

243-
self.transformers_tokenizer = tokenizer
244-
self.mistral = tokenizer.tokenizer
245-
self.instruct = self.mistral.instruct_tokenizer
246-
self.tokenizer = self.instruct.tokenizer
244+
self.transformers_tokenizer: MistralCommonBackend = tokenizer
245+
self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
246+
self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
247+
self.tokenizer: Tokenizer = self.instruct.tokenizer
247248

248249
mode = self.mistral._chat_completion_request_validator._mode
249250
if mode != ValidationMode.test:
@@ -509,7 +510,7 @@ def convert_ids_to_tokens(
509510
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
510511

511512
non_skip_special_tokens_ids = {
512-
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
513+
self.tokenizer.get_special_token(SpecialTokens.tool_calls),
513514
}
514515
if isinstance(self.instruct, InstructTokenizerV13):
515516
if self.instruct.BEGIN_THINK:
@@ -541,3 +542,66 @@ def convert_ids_to_tokens(
541542
]
542543

543544
return tokens
545+
546+
547+
class MistralLLGTokenizer:
    """Adapter exposing a Mistral tokenizer vocabulary to llguidance.

    Builds the byte-level token table and the special-token id list that
    ``llg.TokenizerWrapper`` expects, and is itself callable as the
    text-encoding function llguidance uses.
    """

    # Id of the end-of-sequence token.
    eos_token_id: int
    # Id of the beginning-of-sequence token.
    bos_token_id: int
    # Byte representation of every token, indexed by token id.
    tokens: list[bytes]
    # Ids of all special (control) tokens.
    special_token_ids: list[int]

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        """Extract the vocabulary from the wrapped mistral-common tokenizer.

        Args:
            tokenizer: The vLLM ``MistralTokenizer`` wrapper; its inner
                mistral-common ``Tokenizer`` is what we read from.

        Raises:
            ValueError: If a special token's text does not normalize to the
                ``<...>`` form llguidance requires.
        """
        self._tokenizer = tokenizer.tokenizer
        self.eos_token_id = self._tokenizer.eos_id
        self.bos_token_id = self._tokenizer.bos_id

        self.tokens: list[bytes] = []
        self.special_token_ids: list[int] = []

        seen_special_tokens: set[str] = set()
        for i in range(self._tokenizer.n_words):
            if self._tokenizer.is_special(i):
                # Convert square brackets to angle brackets for special
                # tokens, since llg only recognizes the latter.
                token_rep = self._tokenizer.id_to_piece(i)
                if match := re.fullmatch(r"\[(.*)\]", token_rep):
                    token_rep_llg = f"<{match.group(1)}>"
                else:
                    token_rep_llg = token_rep

                if not re.fullmatch(r"<.*>", token_rep_llg):
                    raise ValueError(
                        f"Invalid special token: {token_rep_llg} ({token_rep})"
                    )
                # Renaming must not collide with another special token.
                assert token_rep_llg not in seen_special_tokens, (
                    token_rep_llg,
                    seen_special_tokens,
                )
                seen_special_tokens.add(token_rep_llg)
                self.special_token_ids.append(i)
                self.tokens.append(token_rep_llg.encode("utf-8"))
            else:
                token = self._tokenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE)
                self.tokens.append(token)

        # Sanity check: we classified exactly the declared special tokens.
        assert len(self.special_token_ids) == self._tokenizer.num_special_tokens, (
            len(self.special_token_ids),
            self._tokenizer.num_special_tokens,
        )

    def __call__(self, s: str, *args, **kwds) -> list[int]:
        """Encode *s* to token ids without BOS/EOS markers.

        HACK: for SentencePiece tokenizers, prepend a null byte so the
        tokenizer does not fold the first character of *s* into a token
        starting with the word-boundary marker "▁"; the first two tokens
        (the "▁" and the null byte) are then dropped, leaving the pure
        tokenization of the text.
        """
        if isinstance(self._tokenizer, SentencePieceTokenizer):
            return self._tokenizer.encode("\00" + s, bos=False, eos=False)[2:]
        else:
            return self._tokenizer.encode(s, bos=False, eos=False)
603+
604+
605+
def guidance_tokenizer_from_mistral_tokenizer(
    tokenizer: MistralTokenizer,
) -> llg.LLTokenizer:
    """Build an llguidance ``LLTokenizer`` from a vLLM Mistral tokenizer.

    Args:
        tokenizer: The vLLM ``MistralTokenizer`` wrapper. Note: this must be
            the wrapper, not a bare mistral-common ``Tokenizer`` —
            ``MistralLLGTokenizer`` reads ``tokenizer.tokenizer`` internally.

    Returns:
        An ``llg.LLTokenizer`` backed by the Mistral vocabulary.
    """
    tokenizer_data = MistralLLGTokenizer(tokenizer)
    return llg.LLTokenizer(llg.TokenizerWrapper(tokenizer_data))

0 commit comments

Comments
 (0)