22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import json
5+ from collections .abc import Generator
6+ from typing import Optional
57
68import pytest
79
8- from vllm .entrypoints .openai .protocol import FunctionCall , ToolCall
10+ from vllm .entrypoints .openai .protocol import (ChatCompletionRequest ,
11+ DeltaMessage , FunctionCall ,
12+ ToolCall )
913from vllm .entrypoints .openai .tool_parsers import xLAMToolParser
10- from vllm .transformers_utils .tokenizer import get_tokenizer
14+ from vllm .transformers_utils .detokenizer import detokenize_incrementally
15+ from vllm .transformers_utils .tokenizer import AnyTokenizer , get_tokenizer
1116
1217# Use a common model that is likely to be available
1318MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
3641 assert actual_tool_call .function == expected_tool_call .function
3742
3843
def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: AnyTokenizer,
    model_output: str,
    request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
    """Drive the xLAM parser one token at a time, yielding each delta.

    Mimics the serving layer's streaming loop: ``model_output`` is
    re-encoded, then detokenized incrementally so the parser sees the
    same (previous_text, current_text, delta_text) triples plus the
    corresponding token-id views it would receive in production.
    Only non-empty ``DeltaMessage`` results are yielded.
    """
    all_token_ids = xlam_tokenizer.encode(model_output,
                                          add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for idx, token_id in enumerate(all_token_ids):
        # Per-step token views: everything before, everything through,
        # and just the single new token.
        previous_token_ids = all_token_ids[:idx]
        current_token_ids = all_token_ids[:idx + 1]
        delta_token_ids = [token_id]

        new_tokens, delta_text, new_prefix_offset, new_read_offset = (
            detokenize_incrementally(
                tokenizer=xlam_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            ))

        current_text = previous_text + delta_text

        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        # Carry the incremental-detokenization state into the next step.
        previous_text = current_text
        if previous_tokens:
            previous_tokens = previous_tokens + new_tokens
        else:
            previous_tokens = new_tokens
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
93+
3994def test_extract_tool_calls_no_tools (xlam_tool_parser ):
4095 model_output = "This is a test"
4196 extracted_tool_calls = xlam_tool_parser .extract_tool_calls (
@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
51106 "single_tool_with_think_tag" ,
52107 "single_tool_with_json_code_block" ,
53108 "single_tool_with_tool_calls_tag" ,
109+ "single_tool_with_tool_call_xml_tags" ,
54110 ],
55111 argnames = ["model_output" , "expected_tool_calls" , "expected_content" ],
56112 argvalues = [
@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
118174 ],
119175 "I'll check the weather for you." ,
120176 ),
177+ (
178+ """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""" , # noqa: E501
179+ [
180+ ToolCall (function = FunctionCall (
181+ name = "get_current_weather" ,
182+ arguments = json .dumps ({
183+ "city" : "Dallas" ,
184+ "state" : "TX" ,
185+ "unit" : "fahrenheit" ,
186+ }),
187+ ))
188+ ],
189+ "I'll help you check the weather." ,
190+ ),
121191 ],
122192)
123193def test_extract_tool_calls (xlam_tool_parser , model_output ,
@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
245315 assert hasattr (result , "tool_calls" )
246316 assert len (result .tool_calls ) == 1
247317 assert result .tool_calls [0 ].function .name == "get_current_weather"
318+
319+
@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    # Collect every delta the parser emits while streaming token by token.
    chunks = list(
        stream_delta_message_generator(xlam_tool_parser, xlam_tokenizer,
                                       model_output, request))

    # Should have multiple chunks
    assert len(chunks) >= 3

    # Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
    expected_first_tool = expected_tool_calls[0]
    header = next(
        (chunk for chunk in chunks
         if chunk.tool_calls and chunk.tool_calls[0].id), None)
    assert header is not None
    assert (header.tool_calls[0].function.name ==
            expected_first_tool.function.name)
    assert header.tool_calls[0].type == "function"
    # Arguments may be empty initially or None
    if header.tool_calls[0].function.arguments is not None:
        # If present, should be empty string initially
        assert header.tool_calls[0].function.arguments == ""

    # Should have chunks with incremental arguments; only collect
    # arguments belonging to the first tool call (index == 0).
    arg_chunks = [
        chunk.tool_calls[0].function.arguments for chunk in chunks
        if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
            and chunk.tool_calls[0].function.arguments != ""
            and chunk.tool_calls[0].index == 0)
    ]

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON for the first tool call
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert parsed_args == expected_args