Skip to content

Commit 9cda30f

Browse files
chaunceyjiang and Yuqi Zhang
authored and committed
[Feature][V1] Support tool_choice: required when using Xgrammar as the StructuredOutputBackend. (vllm-project#17845)
Signed-off-by: chaunceyjiang <[email protected]> Signed-off-by: Yuqi Zhang <[email protected]>
1 parent fa7e846 commit 9cda30f

File tree

5 files changed

+160
-13
lines changed

5 files changed

+160
-13
lines changed

requirements/common.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ lm-format-enforcer >= 0.10.11, < 0.11
2222
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
2323
outlines == 0.1.11
2424
lark == 1.2.2
25-
xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64"
25+
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64"
2626
typing_extensions >= 4.10
2727
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
2828
partial-json-parser # used for parsing partial JSON outputs
Lines changed: 154 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,154 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import openai # use the official client for correctness check
4+
import pytest
5+
import pytest_asyncio
6+
7+
# downloading lora to test lora requests
8+
from ...utils import RemoteOpenAIServer
9+
10+
# any model with a chat template should work here
11+
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
12+
13+
14+
@pytest.fixture(scope="module")
15+
def server(): # noqa: F811
16+
args = [
17+
# use half precision for speed and memory savings in CI environment
18+
"--dtype",
19+
"half",
20+
"--enable-auto-tool-choice",
21+
"--guided-decoding-backend",
22+
"xgrammar",
23+
"--tool-call-parser",
24+
"hermes"
25+
]
26+
27+
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
28+
yield remote_server
29+
30+
31+
@pytest_asyncio.fixture
32+
async def client(server):
33+
async with server.get_async_client() as async_client:
34+
yield async_client
35+
36+
37+
@pytest.mark.asyncio
38+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
39+
async def test_required_tool_use(client: openai.AsyncOpenAI, model_name: str):
40+
tools = [
41+
{
42+
"type": "function",
43+
"function": {
44+
"name": "get_current_weather",
45+
"description": "Get the current weather in a given location",
46+
"parameters": {
47+
"type": "object",
48+
"properties": {
49+
"city": {
50+
"type": "string",
51+
"description":
52+
"The city to find the weather for, e.g. 'Vienna'",
53+
"default": "Vienna",
54+
},
55+
"country": {
56+
"type":
57+
"string",
58+
"description":
59+
"The country that the city is in, e.g. 'Austria'",
60+
},
61+
"unit": {
62+
"type": "string",
63+
"description":
64+
"The unit to fetch the temperature in",
65+
"enum": ["celsius", "fahrenheit"],
66+
},
67+
},
68+
"required": ["country", "unit"],
69+
},
70+
},
71+
},
72+
{
73+
"type": "function",
74+
"function": {
75+
"name": "get_forecast",
76+
"description": "Get the weather forecast for a given location",
77+
"parameters": {
78+
"type": "object",
79+
"properties": {
80+
"city": {
81+
"type": "string",
82+
"description":
83+
"The city to get the forecast for, e.g. 'Vienna'",
84+
"default": "Vienna",
85+
},
86+
"country": {
87+
"type":
88+
"string",
89+
"description":
90+
"The country that the city is in, e.g. 'Austria'",
91+
},
92+
"days": {
93+
"type":
94+
"integer",
95+
"description":
96+
"Number of days to get the forecast for (1-7)",
97+
},
98+
"unit": {
99+
"type": "string",
100+
"description":
101+
"The unit to fetch the temperature in",
102+
"enum": ["celsius", "fahrenheit"],
103+
},
104+
},
105+
"required": ["country", "days", "unit"],
106+
},
107+
},
108+
},
109+
]
110+
111+
messages = [
112+
{
113+
"role": "user",
114+
"content": "Hi! How are you doing today?"
115+
},
116+
{
117+
"role": "assistant",
118+
"content": "I'm doing well! How can I help you?"
119+
},
120+
{
121+
"role":
122+
"user",
123+
"content":
124+
"Can you tell me what the current weather is in Berlin and the "\
125+
"forecast for the next 5 days, in fahrenheit?",
126+
},
127+
]
128+
129+
# Non-streaming test
130+
chat_completion = await client.chat.completions.create(
131+
messages=messages,
132+
model=model_name,
133+
tools=tools,
134+
tool_choice="required",
135+
)
136+
137+
assert chat_completion.choices[0].message.tool_calls is not None
138+
assert len(chat_completion.choices[0].message.tool_calls) > 0
139+
140+
# Streaming test
141+
stream = await client.chat.completions.create(
142+
messages=messages,
143+
model=model_name,
144+
tools=tools,
145+
tool_choice="required",
146+
stream=True,
147+
)
148+
149+
output = []
150+
async for chunk in stream:
151+
if chunk.choices and chunk.choices[0].delta.tool_calls:
152+
output.extend(chunk.choices[0].delta.tool_calls)
153+
154+
assert len(output) > 0

tests/v1/entrypoints/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,9 @@ def sample_json_schema():
7474
},
7575
"required": ["company", "duration", "position"],
7676
"additionalProperties": False
77-
}
77+
},
78+
"minItems": 0,
79+
"maxItems": 3
7880
}
7981
},
8082
"required":

tests/v1/structured_output/test_utils.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -57,14 +57,6 @@ def unsupported_array_schemas():
5757
"type": "array",
5858
"maxContains": 5
5959
},
60-
{
61-
"type": "array",
62-
"minItems": 1
63-
},
64-
{
65-
"type": "array",
66-
"maxItems": 10
67-
},
6860
]
6961

7062

vllm/v1/structured_output/backend_xgrammar.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -215,9 +215,8 @@ def check_object(obj: dict[str, Any]) -> bool:
215215

216216
# Check for array unsupported keywords
217217
if obj.get("type") == "array" and any(
218-
key in obj
219-
for key in ("uniqueItems", "contains", "minContains",
220-
"maxContains", "minItems", "maxItems")):
218+
key in obj for key in ("uniqueItems", "contains",
219+
"minContains", "maxContains")):
221220
return True
222221

223222
# Unsupported keywords for strings

0 commit comments

Comments (0)