Skip to content

Commit 0eecb31

Browse files
david6666666BruceW-07chaunceyjiang
authored
[Bugfix] Fix hermes tool parser handling of non-string argument types (#22002)
Signed-off-by: wangzi <[email protected]> Signed-off-by: David Chen <[email protected]> Co-authored-by: wangzi <[email protected]> Co-authored-by: Chauncey <[email protected]>
1 parent 793be8d commit 0eecb31

File tree

2 files changed

+166
-7
lines changed

2 files changed

+166
-7
lines changed

tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,39 @@
4545
},
4646
}]
4747

48+
PRODUCT_TOOLS = [{
49+
"type": "function",
50+
"function": {
51+
"name": "get_product_info",
52+
"description": "Get detailed information of a product based on its "
53+
"product ID.",
54+
"parameters": {
55+
"type": "object",
56+
"properties": {
57+
"inserted": {
58+
"type": "boolean",
59+
"description": "inserted.",
60+
},
61+
"product_id": {
62+
"type": "integer",
63+
"description": "The product ID of the product.",
64+
},
65+
},
66+
"required": ["product_id", "inserted"],
67+
},
68+
},
69+
}]
70+
4871
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
4972

73+
PRODUCT_MESSAGES = [{
74+
"role":
75+
"user",
76+
"content":
77+
"Hi! Do you have any detailed information about the product id "
78+
"7355608 and inserted true?"
79+
}]
80+
5081

5182
@pytest.mark.asyncio
5283
async def test_non_streaming_tool_call():
@@ -127,3 +158,103 @@ async def test_streaming_tool_call():
127158
print("\n[Streaming Test Passed]")
128159
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
129160
print(f"Reconstructed Arguments: {arguments}")
161+
162+
163+
@pytest.mark.asyncio
164+
async def test_non_streaming_product_tool_call():
165+
"""Test tool call integer and boolean parameters in non-streaming mode."""
166+
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
167+
client = server.get_async_client()
168+
169+
response = await client.chat.completions.create(
170+
model=LORA_MODEL,
171+
messages=PRODUCT_MESSAGES,
172+
tools=PRODUCT_TOOLS,
173+
tool_choice="auto",
174+
temperature=0.66,
175+
)
176+
177+
assert response.choices
178+
choice = response.choices[0]
179+
message = choice.message
180+
181+
assert choice.finish_reason == "tool_calls"
182+
assert message.tool_calls is not None
183+
184+
tool_call = message.tool_calls[0]
185+
assert tool_call.type == "function"
186+
assert tool_call.function.name == "get_product_info"
187+
188+
arguments = json.loads(tool_call.function.arguments)
189+
assert "product_id" in arguments
190+
assert "inserted" in arguments
191+
192+
product_id = arguments.get("product_id")
193+
inserted = arguments.get("inserted")
194+
195+
assert isinstance(product_id, int)
196+
assert product_id == 7355608
197+
assert isinstance(inserted, bool)
198+
assert inserted is True
199+
200+
print("\n[Non-Streaming Product Test Passed]")
201+
print(f"Tool Call: {tool_call.function.name}")
202+
print(f"Arguments: {arguments}")
203+
204+
205+
@pytest.mark.asyncio
206+
async def test_streaming_product_tool_call():
207+
"""Test tool call integer and boolean parameters in streaming mode."""
208+
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
209+
client = server.get_async_client()
210+
211+
stream = await client.chat.completions.create(
212+
model=LORA_MODEL,
213+
messages=PRODUCT_MESSAGES,
214+
tools=PRODUCT_TOOLS,
215+
tool_choice="auto",
216+
temperature=0.66,
217+
stream=True,
218+
)
219+
220+
tool_call_chunks = {}
221+
async for chunk in stream:
222+
if not chunk.choices:
223+
continue
224+
225+
delta = chunk.choices[0].delta
226+
if not delta or not delta.tool_calls:
227+
continue
228+
229+
for tool_chunk in delta.tool_calls:
230+
index = tool_chunk.index
231+
if index not in tool_call_chunks:
232+
tool_call_chunks[index] = {"name": "", "arguments": ""}
233+
234+
if tool_chunk.function.name:
235+
tool_call_chunks[index]["name"] += tool_chunk.function.name
236+
if tool_chunk.function.arguments:
237+
tool_call_chunks[index][
238+
"arguments"] += tool_chunk.function.arguments
239+
240+
assert len(tool_call_chunks) == 1
241+
reconstructed_tool_call = tool_call_chunks[0]
242+
243+
assert reconstructed_tool_call["name"] == "get_product_info"
244+
245+
arguments = json.loads(reconstructed_tool_call["arguments"])
246+
assert "product_id" in arguments
247+
assert "inserted" in arguments
248+
249+
# Handle type coercion for streaming test as well
250+
product_id = arguments.get("product_id")
251+
inserted = arguments.get("inserted")
252+
253+
assert isinstance(product_id, int)
254+
assert product_id == 7355608
255+
assert isinstance(inserted, bool)
256+
assert inserted is True
257+
258+
print("\n[Streaming Product Test Passed]")
259+
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
260+
print(f"Reconstructed Arguments: {arguments}")

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,32 @@ def extract_tool_calls_streaming(
368368
# case -- we now have the first info about arguments available from
369369
# autocompleting the JSON
370370
elif cur_arguments and not prev_arguments:
371+
# extract the content after {"name": ..., "arguments":
372+
# directly from tool_call_portion as cur_arguments_json,
373+
# since cur_arguments may differ from the original text
374+
# due to partial JSON parsing
375+
# for example, tool_call_portion =
376+
# {"name": "search", "arguments": {"search_request": {"
377+
# but cur_arguments =
378+
# {"search_request": {}}
379+
function_name = current_tool_call.get("name")
380+
match = re.search(
381+
r'\{"name":\s*"' +
382+
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
383+
tool_call_portion.strip(), re.DOTALL)
384+
if match:
385+
cur_arguments_json = match.group(1)
386+
else:
387+
cur_arguments_json = json.dumps(cur_arguments,
388+
ensure_ascii=False)
371389

372-
cur_arguments_json = json.dumps(cur_arguments,
373-
ensure_ascii=False)
374390
logger.debug("finding %s in %s", delta_text,
375391
cur_arguments_json)
376392

377-
# get the location where previous args differ from current
378-
if (delta_text not in cur_arguments_json[:-2]):
393+
# get the location where previous args differ from current.
394+
if (delta_text not in cur_arguments_json):
379395
return None
380-
args_delta_start_loc = cur_arguments_json[:-2]. \
396+
args_delta_start_loc = cur_arguments_json. \
381397
rindex(delta_text) + \
382398
len(delta_text)
383399

@@ -397,8 +413,20 @@ def extract_tool_calls_streaming(
397413

398414
# last case -- we have an update to existing arguments.
399415
elif cur_arguments and prev_arguments:
400-
if isinstance(delta_text, str) and len(delta_text.rstrip(
401-
)) >= 1 and delta_text.rstrip()[-1] == '}':
416+
# judge whether the tool_call_portion is a complete JSON
417+
try:
418+
json.loads(tool_call_portion)
419+
is_complete_json = True
420+
except Exception:
421+
is_complete_json = False
422+
423+
# if the delta_text ends with a '}' and tool_call_portion is a
424+
# complete JSON, then the last '}' does not belong to the
425+
# arguments, so we should trim it off
426+
if isinstance(delta_text, str) \
427+
and len(delta_text.rstrip()) >= 1 \
428+
and delta_text.rstrip()[-1] == '}' \
429+
and is_complete_json:
402430
delta_text = delta_text.rstrip()[:-1]
403431

404432
logger.debug("got diff %s", delta_text)

0 commit comments

Comments
 (0)