Commit 422e793

[Bugfix] Add support for <tool_call> format in streaming mode for XLAM Tool Parser (#22769)
Signed-off-by: Devon Peroutky <[email protected]>
1 parent 1cb39db commit 422e793

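This change teaches the xLAM tool parser to recognize tool calls wrapped in <tool_call>...</tool_call> XML tags, in both batch and streaming modes. As the new test cases below show, the model may emit free-form content before the tagged JSON list:

    I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>

A minimal sketch of exercising the non-streaming path on this format, assuming the tokenizer for MODEL is available locally and using the parser interface imported in the tests below (the result's tool_calls/content fields follow vLLM's ExtractedToolCallInformation):

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
    from vllm.transformers_utils.tokenizer import get_tokenizer

    MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
    parser = xLAMToolParser(get_tokenizer(MODEL))

    # Model output with leading content and a <tool_call>-tagged JSON list
    output = ('I can help with that.<tool_call>'
              '[{"name": "get_current_weather", "arguments": '
              '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]'
              '</tool_call>')
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
    result = parser.extract_tool_calls(output, request)

    # Expected per the tests below: one parsed call, leading text kept as content
    assert result.tool_calls[0].function.name == "get_current_weather"
    assert result.content == "I can help with that."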
File tree

2 files changed: +297 -25 lines changed


tests/tool_use/test_xlam_tool_parser.py

Lines changed: 216 additions & 2 deletions
@@ -2,12 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import json
+from collections.abc import Generator
+from typing import Optional
 
 import pytest
 
-from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage, FunctionCall,
+                                              ToolCall)
 from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 
 # Use a common model that is likely to be available
 MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
         assert actual_tool_call.function == expected_tool_call.function
 
 
+def stream_delta_message_generator(
+    xlam_tool_parser: xLAMToolParser,
+    xlam_tokenizer: AnyTokenizer,
+    model_output: str,
+    request: Optional[ChatCompletionRequest] = None,
+) -> Generator[DeltaMessage, None, None]:
+    all_token_ids = xlam_tokenizer.encode(model_output,
+                                          add_special_tokens=False)
+
+    previous_text = ""
+    previous_tokens = None
+    prefix_offset = 0
+    read_offset = 0
+    for i, delta_token in enumerate(all_token_ids):
+        delta_token_ids = [delta_token]
+        previous_token_ids = all_token_ids[:i]
+        current_token_ids = all_token_ids[:i + 1]
+
+        (new_tokens, delta_text, new_prefix_offset,
+         new_read_offset) = (detokenize_incrementally(
+             tokenizer=xlam_tokenizer,
+             all_input_ids=current_token_ids,
+             prev_tokens=previous_tokens,
+             prefix_offset=prefix_offset,
+             read_offset=read_offset,
+             skip_special_tokens=False,
+             spaces_between_special_tokens=True,
+         ))
+
+        current_text = previous_text + delta_text
+
+        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+            request=request,
+        )
+        if delta_message:
+            yield delta_message
+
+        previous_text = current_text
+        previous_tokens = (previous_tokens +
+                           new_tokens if previous_tokens else new_tokens)
+        prefix_offset = new_prefix_offset
+        read_offset = new_read_offset
+
+
 def test_extract_tool_calls_no_tools(xlam_tool_parser):
     model_output = "This is a test"
     extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
         "single_tool_with_think_tag",
         "single_tool_with_json_code_block",
         "single_tool_with_tool_calls_tag",
+        "single_tool_with_tool_call_xml_tags",
     ],
     argnames=["model_output", "expected_tool_calls", "expected_content"],
     argvalues=[
@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
             ],
             "I'll check the weather for you.",
         ),
+        (
+            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "I'll help you check the weather.",
+        ),
     ],
 )
 def test_extract_tool_calls(xlam_tool_parser, model_output,
@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
     assert hasattr(result, "tool_calls")
     assert len(result.tool_calls) == 1
     assert result.tool_calls[0].function.name == "get_current_weather"
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "parallel_tool_calls",
+        "single_tool_with_think_tag",
+        "single_tool_with_json_code_block",
+        "single_tool_with_tool_calls_tag",
+        "single_tool_with_tool_call_xml_tags",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Orlando",
+                        "state": "FL",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+            ],
+            "",
+        ),
+        (
+            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "<think>I'll help you with that.</think>",
+        ),
+        (
+            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "I can help with that.",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_incremental(
+    xlam_tool_parser,
+    xlam_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    """Verify the xLAM parser's streaming behavior by checking each chunk."""
+    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
+
+    chunks = []
+    for delta_message in stream_delta_message_generator(
+            xlam_tool_parser, xlam_tokenizer, model_output, request):
+        chunks.append(delta_message)
+
+    # Should have multiple chunks
+    assert len(chunks) >= 3
+
+    # Should have a chunk with tool header (id, name, type) for the first tool call  # noqa: E501
+    header_found = False
+    expected_first_tool = expected_tool_calls[0]
+    for chunk in chunks:
+        if chunk.tool_calls and chunk.tool_calls[0].id:
+            header_found = True
+            assert (chunk.tool_calls[0].function.name ==
+                    expected_first_tool.function.name)
+            assert chunk.tool_calls[0].type == "function"
+            # Arguments may be empty initially or None
+            if chunk.tool_calls[0].function.arguments is not None:
+                # If present, should be empty string initially
+                assert chunk.tool_calls[0].function.arguments == ""
+            break
+    assert header_found
+
+    # Should have chunks with incremental arguments
+    arg_chunks = []
+    for chunk in chunks:
+        if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
+                and chunk.tool_calls[0].function.arguments != ""
+                and chunk.tool_calls[0].index ==
+                0  # Only collect arguments from the first tool call
+            ):
+            arg_chunks.append(chunk.tool_calls[0].function.arguments)
+
+    # Arguments should be streamed incrementally
+    assert len(arg_chunks) > 1
+
+    # Concatenated arguments should form valid JSON for the first tool call
+    full_args = "".join(arg_chunks)
+    parsed_args = json.loads(full_args)
+    expected_args = json.loads(expected_first_tool.function.arguments)
+    assert parsed_args == expected_args
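The streaming test drives the parser the way vLLM's OpenAI-compatible server does: tokens are decoded one at a time with detokenize_incrementally, each delta is passed to extract_tool_calls_streaming, and the resulting DeltaMessage chunks must reassemble into an OpenAI-style tool call (a header chunk carrying id, name, and type first, then argument fragments that concatenate into valid JSON).

Assuming a standard vLLM dev checkout, the new streaming cases can presumably be selected by their parametrize ids, e.g.:

    pytest tests/tool_use/test_xlam_tool_parser.py -k "tool_call_xml_tags"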
