
Commit 11cd1ae

[Tool parsing] Improve / correct mistral tool parsing (#10333)
1 parent 554af92 commit 11cd1ae

5 files changed, +172 -59 lines changed

tests/models/decoder_only/language/test_mistral.py

Lines changed: 82 additions & 11 deletions
@@ -2,9 +2,13 @@
 
 Run `pytest tests/models/test_mistral.py`.
 """
+import copy
+
 import pytest
 
 from vllm import SamplingParams
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
+    MistralToolParser)
 
 from ...utils import check_logprobs_close
 
@@ -58,17 +62,69 @@
             },
             "required": ["city", "state", "unit"]
         }
+    },
+}, {
+    "type": "function",
+    "function": {
+        "name": "rewrite",
+        "description": "Rewrites text",
+        "parameters": {
+            "type": "object",
+            "required": [],
+            "properties": {
+                "text": {
+                    "type": "string",
+                    "description": "The input text to rewrite."
+                }
+            }
+        }
     }
 }]
-MSGS = [{
-    "role":
-    "user",
-    "content": ("Can you tell me what the temperate"
-                " will be in Dallas, in fahrenheit?")
-}]
-EXPECTED_FUNC_CALL = (
-    '[{"name": "get_current_weather", "arguments": '
-    '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
+MSGS = [
+    {
+        "role": "system",
+        "content": "You are an assistant."
+    },
+    {
+        "role":
+        "user",
+        "content":
+        "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors."  # noqa
+    },
+    {
+        "role":
+        "assistant",
+        "content":
+        "",
+        "tool_calls": [{
+            "id": "bbc5b7ede",
+            "type": "function",
+            "function": {
+                "name":
+                "rewrite",
+                "arguments":
+                '{\"text\":\"My English needs improvving, maybe I make errors.\"}'  # noqa
+            }
+        }]
+    },
+    {
+        "role": "tool",
+        "content":
+        "{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}",  # noqa
+        "tool_call_id": "bbc5b7ede",
+        "name": "rewrite"
+    },
+    {
+        "role": "assistant",
+        "content": "---\n\nMy English needs improving, maybe I make errors"
+    },
+    {
+        "role":
+        "user",
+        "content": ("Can you tell me what the temperate"
+                    " will be in Dallas, in fahrenheit?")
+    }
+]
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
                      tokenizer_mode="mistral",
                      config_format="mistral",
                      load_format="mistral") as vllm_model:
-        outputs = vllm_model.model.chat(MSGS,
+
+        msgs = copy.deepcopy(MSGS)
+        outputs = vllm_model.model.chat(msgs,
                                         tools=TOOLS,
                                         sampling_params=SAMPLING_PARAMS)
 
-        assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
+        tokenizer = vllm_model.model.get_tokenizer()
+        tool_parser = MistralToolParser(tokenizer)
+
+        model_output = outputs[0].outputs[0].text.strip()
+        assert model_output.startswith(tool_parser.bot_token), model_output
+        parsed_message = tool_parser.extract_tool_calls(model_output, None)
+
+        assert parsed_message.tools_called
+        assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
+        assert parsed_message.tool_calls[
+            0].function.name == "get_current_weather"
+        assert parsed_message.tool_calls[
+            0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}'  # noqa
+        assert parsed_message.content is None

vllm/entrypoints/openai/serving_chat.py

Lines changed: 5 additions & 34 deletions
@@ -30,6 +30,7 @@
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
 from vllm.utils import iterate_with_cancellation
 
 logger = init_logger(__name__)
@@ -127,41 +128,11 @@ async def create_chat_completion(
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # NOTE: There is currently a bug in pydantic where attributes
-            # declared as iterables are replaced in in the instances by
-            # pydantic-core ValidatorIterator instance. In particular, this
-            # affects tool_calls defined in ChatCompletionAssistantMessageParam
-            # model:
-            # see:
-            # - https://github.com/pydantic/pydantic/issues/9467
-            # As a result, tool_calls from assistant messages are never
-            # deserialized in the request object if the tool_calls iterator is
-            # not consumed. This affect messages passed to the MistralTokenizer
-            # since no chat template is applied and therefore the tools_calls
-            # iterator is not directly consumed.
-            # Issue is tracked on Pydantic side, with resolution planned for
-            # v2.11 release. In the meantime, the official workaround is to
-            # consume the iterator so the tool_calls are correctly deserialized
-            # in the OpenAI ChatCompletionAssistantMessageParam object
-            # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
-            # Official Pydantic Issues:
-            # - https://github.com/pydantic/pydantic/issues/9541
-            # TODO: remove when pydantic v2.11 is released
+            # because of issues with pydantic we need to potentially
+            # re-serialize the tool_calls field of the request
+            # for more info: see comment in `maybe_serialize_tool_calls`
             if isinstance(tokenizer, MistralTokenizer):
-                for i, message in enumerate(request.messages):
-                    if message.get("role") == 'assistant':
-                        tool_calls_validator = message.get(
-                            "tool_calls", ().__iter__())
-                        validated_tool_calls = []
-                        while True:
-                            try:
-                                tool_call = next(
-                                    tool_calls_validator)  # type: ignore
-                                validated_tool_calls.append(tool_call)
-                            except StopIteration:
-                                break
-                        request.messages[i][
-                            "tool_calls"] = validated_tool_calls
+                maybe_serialize_tool_calls(request)
 
             if (request.tool_choice == "auto" and
                     not (self.enable_auto_tools and tool_parser is not None)

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 17 additions & 8 deletions
@@ -62,7 +62,7 @@ def __init__(self, tokenizer: AnyTokenizer):
         ]  # map what has been streamed for each tool so far to a list
         self.bot_token = "[TOOL_CALLS]"
         self.bot_token_id = self.vocab.get(self.bot_token)
-        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
+        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
         if self.bot_token_id is None:
             raise RuntimeError(
                 "Mistral Tool Parser could not locate the tool call token in "
@@ -84,16 +84,25 @@ def extract_tool_calls(
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
                                                 content=model_output)
+
+        # first remove the BOT token
+        tool_content = model_output.replace(self.bot_token, "").strip()
+
         try:
 
-            # use a regex to find the tool call. remove the BOT token
-            # and make sure to replace single quotes with double quotes
-            raw_tool_call = self.tool_call_regex.findall(
-                model_output.replace(self.bot_token, ""))[0]
+            # we first try to directly load the json as parsing very nested
+            # jsons is difficult
+            try:
+                function_call_arr = json.loads(tool_content)
+            except json.JSONDecodeError:
+                # use a regex to find the part corresponding to the tool call.
+                # NOTE: This use case should not happen if the model is trained
+                # correctly. It's a easy possible fix so it's included, but
+                # can be brittle for very complex / highly nested tool calls
+                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
+                function_call_arr = json.loads(raw_tool_call)
 
-            # load the JSON, and then use it to build the Function and
             # Tool Call
-            function_call_arr = json.loads(raw_tool_call)
             tool_calls: List[MistralToolCall] = [
                 MistralToolCall(
                     type="function",
@@ -116,7 +125,7 @@ def extract_tool_calls(
             # return information to just treat the tool call as regular JSON
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
-                                                content=model_output)
+                                                content=tool_content)
 
     def extract_tool_calls_streaming(
         self,
vllm/transformers_utils/tokenizers/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-from .mistral import MistralTokenizer
+from .mistral import MistralTokenizer, maybe_serialize_tool_calls
 
-__all__ = ["MistralTokenizer"]
+__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 66 additions & 4 deletions
@@ -7,6 +7,7 @@
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.base import SpecialTokens
 # yapf: disable
 from mistral_common.tokens.tokenizers.mistral import (
     MistralTokenizer as PublicMistralTokenizer)
@@ -29,6 +30,43 @@ class Encoding:
     input_ids: List[int]
 
 
+def maybe_serialize_tool_calls(request: ChatCompletionRequest):
+    # SEE: https://github.com/vllm-project/vllm/pull/9951
+    # Credits go to: @gcalmettes
+    # NOTE: There is currently a bug in pydantic where attributes
+    # declared as iterables are replaced in in the instances by
+    # pydantic-core ValidatorIterator instance. In particular, this
+    # affects tool_calls defined in ChatCompletionAssistantMessageParam
+    # model:
+    # see:
+    # - https://github.com/pydantic/pydantic/issues/9467
+    # As a result, tool_calls from assistant messages are never
+    # deserialized in the request object if the tool_calls iterator is
+    # not consumed. This affect messages passed to the MistralTokenizer
+    # since no chat template is applied and therefore the tools_calls
+    # iterator is not directly consumed.
+    # Issue is tracked on Pydantic side, with resolution planned for
+    # v2.11 release. In the meantime, the official workaround is to
+    # consume the iterator so the tool_calls are correctly deserialized
+    # in the OpenAI ChatCompletionAssistantMessageParam object
+    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
+    # Official Pydantic Issues:
+    # - https://github.com/pydantic/pydantic/issues/9541
+    # TODO: remove when pydantic v2.11 is released
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls_validator = message.get("tool_calls", ().__iter__())
+            validated_tool_calls = []
+            while True:
+                try:
+                    tool_call = next(tool_calls_validator)  # type: ignore
+                    validated_tool_calls.append(tool_call)
+                except StopIteration:
+                    break
+
+            request.messages[i]["tool_calls"] = validated_tool_calls
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
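The bug this works around is easiest to see in isolation. A minimal sketch, assuming pydantic v2's lazy validation of `Iterable` fields (the model names here are illustrative, not vLLM's):

from typing import Iterable

from pydantic import BaseModel

class ToolCall(BaseModel):
    id: str

class AssistantMessage(BaseModel):
    # Fields annotated as Iterable are validated lazily by pydantic v2:
    # the attribute holds a pydantic-core ValidatorIterator until consumed.
    tool_calls: Iterable[ToolCall]

msg = AssistantMessage(tool_calls=[{"id": "bbc5b7ede"}])
print(type(msg.tool_calls).__name__)  # ValidatorIterator, not list

# Consuming the iterator materializes the validated objects, which is what
# maybe_serialize_tool_calls does for every assistant message in a request.
msg.tool_calls = list(msg.tool_calls)
print(msg.tool_calls)  # [ToolCall(id='bbc5b7ede')]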
@@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
         if self.is_tekken:
             tokens = [
                 t for t in tokens
-                if t not in self.tokenizer._all_special_tokens
+                if (t is SpecialTokens.tool_calls
+                    or t not in self.tokenizer._all_special_tokens)
             ]
 
         if any(isinstance(t, bytes) for t in tokens):
@@ -246,7 +285,27 @@ def _token_to_id(t: str):
             else:
                 decoded = "".join(tokens)
         else:
-            decoded = self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+            # make sure certain special tokens like Tool calls are
+            # not decoded
+            special_tokens = {SpecialTokens.tool_calls}
+            regular_tokens: List[str] = []
+            decoded_list = []
+
+            for token in tokens:
+                if token in special_tokens:
+                    if regular_tokens:
+                        decoded_list.append(
+                            self.tokenizer.decode(regular_tokens))
+                        regular_tokens = []
+                    decoded_list.append(token)
+                else:
+                    regular_tokens.append(token)
+
+            if regular_tokens:
+                decoded_list.append(
+                    self.decode(regular_tokens))  # type: ignore
+
+            decoded = ''.join(decoded_list)
 
         return decoded
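The run-splitting decode above does not depend on the tokenizer itself and can be sketched standalone (the `decode` callable is a stand-in for the underlying mistral-common tokenizer):

from typing import Callable, List

TOOL_CALLS = "[TOOL_CALLS]"

def decode_keeping_specials(tokens: List[str],
                            decode: Callable[[List[str]], str]) -> str:
    special_tokens = {TOOL_CALLS}
    regular_tokens: List[str] = []
    decoded_list: List[str] = []
    for token in tokens:
        if token in special_tokens:
            # Flush the pending run of regular tokens, then emit the
            # special token verbatim instead of decoding it away.
            if regular_tokens:
                decoded_list.append(decode(regular_tokens))
                regular_tokens = []
            decoded_list.append(token)
        else:
            regular_tokens.append(token)
    if regular_tokens:
        decoded_list.append(decode(regular_tokens))
    return "".join(decoded_list)

print(decode_keeping_specials(
    [TOOL_CALLS, '[{"name": ', '"fn"}]'], decode="".join))
# -> [TOOL_CALLS][{"name": "fn"}]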

@@ -274,8 +333,11 @@ def convert_ids_to_tokens(
         assert self.is_tekken or self.is_spm, type(self.tokenizer)
 
         if self.is_tekken:
-            # skip special tokens
-            ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
+            # skip special tokens except tool call
+            ids = [
+                i for i in ids if i > self.tokenizer.num_special_tokens or i ==
+                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
+            ]
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
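Tekken assigns the lowest ids to special tokens, and the old filter dropped them all; the new one whitelists exactly the [TOOL_CALLS] control token so the tool parser can still see it. A toy illustration (both constants are assumptions for the sketch, not real Tekken values):

NUM_SPECIAL_TOKENS = 1000  # assumed boundary, for illustration only
TOOL_CALLS_ID = 5          # assumed id of the [TOOL_CALLS] control token

ids = [3, TOOL_CALLS_ID, 1001, 2048]

# Old behavior: every id at or below the boundary was skipped.
assert [i for i in ids if i > NUM_SPECIAL_TOKENS] == [1001, 2048]

# New behavior: the tool-call control token survives the filter.
kept = [i for i in ids if i > NUM_SPECIAL_TOKENS or i == TOOL_CALLS_ID]
print(kept)  # [5, 1001, 2048]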
