Skip to content

Commit cece2da

Browse files
committed
[Model] Add ApertusToolParser
Signed-off-by: swan.blanc <[email protected]>
1 parent fc67969 commit cece2da

File tree

2 files changed

+312
-0
lines changed

2 files changed

+312
-0
lines changed

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .seed_oss_tool_parser import SeedOssToolParser
2626
from .step3_tool_parser import Step3ToolParser
2727
from .xlam_tool_parser import xLAMToolParser
28+
from .apertus_tool_parser import ApertusToolParser
2829

2930
__all__ = [
3031
"ToolParser",
@@ -52,4 +53,5 @@
5253
"SeedOssToolParser",
5354
"Step3ToolParser",
5455
"OpenAIToolParser",
56+
"ApertusToolParser",
5557
]
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from __future__ import annotations
5+
6+
import json
7+
import re
8+
from collections.abc import Sequence
9+
from typing import TYPE_CHECKING
10+
11+
from partial_json_parser.core.options import Allow
12+
from transformers import PreTrainedTokenizerBase
13+
14+
from vllm.entrypoints.chat_utils import make_tool_call_id
15+
from vllm.entrypoints.openai.protocol import (
16+
ChatCompletionRequest,
17+
DeltaFunctionCall,
18+
DeltaMessage,
19+
DeltaToolCall,
20+
ExtractedToolCallInformation,
21+
FunctionCall,
22+
ToolCall,
23+
)
24+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
25+
ToolParser,
26+
ToolParserManager,
27+
)
28+
from vllm.entrypoints.openai.tool_parsers.utils import (
29+
find_common_prefix,
30+
is_complete_json,
31+
partial_json_loads,
32+
)
33+
from vllm.logger import init_logger
34+
35+
if TYPE_CHECKING:
36+
pass
37+
38+
logger = init_logger(__name__)
39+
40+
41+
@ToolParserManager.register_module("apertus")
class ApertusToolParser(ToolParser):
    """
    Tool call parser for Apertus models.

    Extracts tool calls from the format:
    <|tools_prefix|>[{"function_name": {"arg1": "value1", ...}}, ...]<|tools_suffix|>

    Each array entry is a single-key object mapping the function name to its
    arguments. Text that precedes the prefix is returned as ordinary assistant
    content; entries that are not single-key objects are ignored.

    Used when --enable-auto-tool-choice --tool-call-parser apertus are set.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Set up the delimiters, the streaming state and the block regex."""
        super().__init__(tokenizer)

        # Tokens for tool call delimiters
        self.tool_calls_prefix = "<|tools_prefix|>"
        self.tool_calls_suffix = "<|tools_suffix|>"

        # State for streaming
        # Last parsed tool array seen while streaming arguments.
        self.prev_tool_call_arr: list[dict] = []
        # Index of the tool currently being streamed (-1: none started yet).
        self.current_tool_id: int = -1
        # Whether the name of the current tool has already been emitted.
        self.current_tool_name_sent: bool = False
        # Per-tool argument text already emitted to the client.
        self.streamed_args_for_tool: list[str] = []

        # Regex to extract tool calls block (suffix is optional for incomplete
        # outputs)
        prefix = re.escape(self.tool_calls_prefix)
        suffix = re.escape(self.tool_calls_suffix)
        self.tool_call_regex = re.compile(
            rf"{prefix}(.*?)(?:{suffix}|$)",
            re.DOTALL,
        )

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model response.

        Returns the parsed tool calls together with any text that preceded the
        tool call block. If the block is missing, cannot be parsed as JSON, or
        contains no well-formed entry, the whole output is returned as plain
        content with ``tools_called=False``.
        """
        plain_response = ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

        # Quick check before running regex
        if self.tool_calls_prefix not in model_output:
            return plain_response

        # Find tool calls block
        match = self.tool_call_regex.search(model_output)
        if not match:
            return plain_response

        try:
            tool_call_objects = json.loads(match.group(1).strip())

            # A single bare object is accepted as a one-element array.
            if not isinstance(tool_call_objects, list):
                tool_call_objects = [tool_call_objects]

            tool_calls = self._parse_tool_call_objects(tool_call_objects)
        except Exception:
            logger.exception("Error extracting tool call from response.")
            return plain_response

        # Do not claim a tool call when every entry was malformed; the client
        # would receive neither tool calls nor content.
        if not tool_calls:
            return plain_response

        # Keep whatever the model said before the tool call block.
        leading_text = model_output[: match.start()]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=leading_text if leading_text.strip() else None,
        )

    def _parse_tool_call_objects(self, tool_call_objects: list[dict]) -> list[ToolCall]:
        """Parse tool call objects into ToolCall instances.

        Entries that are not ``{"function_name": arguments}`` single-key
        objects are skipped.
        """
        tool_calls: list[ToolCall] = []

        for obj in tool_call_objects:
            # Each object is {"function_name": {"arg1": "value1", ...}}
            if not isinstance(obj, dict) or len(obj) != 1:
                continue

            function_name, arguments = next(iter(obj.items()))
            tool_calls.append(
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(arguments, ensure_ascii=False),
                    ),
                )
            )

        return tool_calls

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls in streaming mode.

        Returns a content delta while no tool call block has started, a tool
        call delta when there is something new to emit, or ``None`` when more
        tokens are needed.
        """
        # Check if we're in a tool call block
        if self.tool_calls_prefix not in current_text:
            return DeltaMessage(content=delta_text)

        json_str = self._extract_json_str(current_text)

        try:
            tool_call_arr = self._parse_partial_json(json_str)

            if not tool_call_arr:
                return None

            # The array has grown past the current tool: flush what is left of
            # it and move on. Advance exactly one entry per call so that no
            # tool is skipped when several entries arrive in a single delta.
            if len(tool_call_arr) > self.current_tool_id + 1:
                delta = self._finalize_previous_tool(tool_call_arr)
                self._start_new_tool(self.current_tool_id + 2)
                return delta

            current_tool_call = tool_call_arr[self.current_tool_id]

            # Send tool name if not sent yet
            if not self.current_tool_name_sent:
                return self._send_tool_name(current_tool_call)

            # Stream arguments
            delta = self._stream_arguments(current_tool_call, json_str)
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # Partial output frequently fails to parse; keep the traceback at
            # debug level so that genuine bugs remain diagnosable.
            logger.debug(
                "Error parsing streaming tool call, waiting for more tokens",
                exc_info=True,
            )
            return None

    def _extract_json_str(self, current_text: str) -> str:
        """Extract JSON string from the current text.

        Returns the text between the prefix and the suffix (or the end of the
        text when the suffix has not arrived yet), stripped of surrounding
        whitespace. Returns an empty string when the prefix is absent.
        """
        prefix_idx = current_text.find(self.tool_calls_prefix)
        if prefix_idx == -1:
            return ""
        start_idx = prefix_idx + len(self.tool_calls_prefix)

        # Check if suffix is present (complete tool call)
        suffix_idx = current_text.find(self.tool_calls_suffix, start_idx)
        if suffix_idx != -1:
            return current_text[start_idx:suffix_idx].strip()
        return current_text[start_idx:].strip()

    def _parse_partial_json(self, json_str: str) -> list[dict]:
        """Parse partial JSON with appropriate flags.

        Partial strings are disallowed until the current tool's name has been
        sent, so that a half-received function name is never emitted.
        """
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        tool_call_arr, _ = partial_json_loads(json_str, flags)

        if not isinstance(tool_call_arr, list):
            tool_call_arr = [tool_call_arr] if tool_call_arr else []

        return tool_call_arr

    def _finalize_previous_tool(
        self, tool_call_arr: list[dict] | None = None
    ) -> DeltaMessage | None:
        """Finalize any remaining arguments from the previous tool.

        Args:
            tool_call_arr: The freshly parsed tool array. When given, the
                finished entry is read from it rather than from the previous
                snapshot, which may be missing or may hold an earlier partial
                value. Defaults to ``self.prev_tool_call_arr``.

        Returns:
            A delta carrying the unsent tail of the arguments; a delta carrying
            id, name and full arguments when the name was never sent; or
            ``None`` when there is nothing to emit.
        """
        tool_id = self.current_tool_id
        if tool_id < 0:
            return None

        source = self.prev_tool_call_arr if tool_call_arr is None else tool_call_arr
        if tool_id >= len(source) or tool_id >= len(self.streamed_args_for_tool):
            return None

        finished_tool = source[tool_id]
        if not isinstance(finished_tool, dict) or len(finished_tool) != 1:
            return None

        function_name = next(iter(finished_tool))
        args_json = json.dumps(finished_tool[function_name], ensure_ascii=False)

        if not self.current_tool_name_sent:
            # The tool completed before its name could be streamed. Announce
            # the whole call at once instead of sending arguments for a call
            # the client has never been told about.
            self.current_tool_name_sent = True
            self.streamed_args_for_tool[tool_id] = args_json
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=function_name, arguments=args_json
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        sent = len(self.streamed_args_for_tool[tool_id])
        argument_diff = args_json[sent:]

        if not argument_diff:
            return None

        self.streamed_args_for_tool[tool_id] += argument_diff
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=tool_id,
                    function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _start_new_tool(self, array_length: int) -> None:
        """Start processing a new tool.

        Makes the last entry of an array of ``array_length`` items the current
        tool and ensures an argument buffer exists for every index up to it.
        """
        self.current_tool_id = array_length - 1
        self.current_tool_name_sent = False

        # Pad rather than append once so that the buffer index is valid even
        # if the caller moves forward by more than one entry.
        missing = array_length - len(self.streamed_args_for_tool)
        if missing > 0:
            self.streamed_args_for_tool.extend([""] * missing)

    def _send_tool_name(self, current_tool_call: dict) -> DeltaMessage | None:
        """Send the tool name if not sent yet.

        Returns ``None`` while the entry is not yet a single-key object.
        """
        if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
            return None

        function_name = next(iter(current_tool_call))

        self.current_tool_name_sent = True
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                    id=make_tool_call_id(),
                    function=DeltaFunctionCall(name=function_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _stream_arguments(
        self, current_tool_call: dict, json_str: str
    ) -> DeltaMessage | None:
        """Stream arguments for the current tool.

        Returns a delta with the newly available argument text, or ``None``
        when nothing new can be emitted yet.
        """
        if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
            return None

        function_name = next(iter(current_tool_call))
        arguments = current_tool_call[function_name]

        # Empty arguments are only final once the whole array is complete; at
        # that point "{}" must still be emitted so that the streamed arguments
        # form valid JSON, as they do in the non-streaming path.
        if not arguments and not is_complete_json(json_str):
            return None

        sent = len(self.streamed_args_for_tool[self.current_tool_id])
        args_json = json.dumps(arguments, ensure_ascii=False)

        argument_diff = self._calculate_argument_diff(
            function_name, args_json, json_str, sent
        )

        if not argument_diff:
            return None

        self.streamed_args_for_tool[self.current_tool_id] += argument_diff
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _calculate_argument_diff(
        self, function_name: str, args_json: str, json_str: str, sent: int
    ) -> str | None:
        """Calculate the difference in arguments to stream.

        Args:
            function_name: Name of the tool currently being streamed.
            args_json: Serialized arguments parsed from the current text.
            json_str: Raw JSON text of the tool call block received so far.
            sent: Number of argument characters already emitted for this tool.

        Returns:
            The argument text to emit next, or ``None`` when there is no
            previous snapshot to compare against or nothing has changed.
        """
        # Once the block is complete nothing can change any more: flush all.
        if is_complete_json(json_str):
            return args_json[sent:]

        if not self.prev_tool_call_arr or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            return None

        prev_tool = self.prev_tool_call_arr[self.current_tool_id]
        if not isinstance(prev_tool, dict) or len(prev_tool) != 1:
            return None

        prev_function_name = next(iter(prev_tool))
        if prev_function_name != function_name:
            return None

        prev_args_json = json.dumps(prev_tool[prev_function_name], ensure_ascii=False)
        if args_json == prev_args_json:
            return None

        # Only the part shared by two consecutive snapshots is stable; the
        # remainder was auto-completed by the partial parser and may change.
        prefix = find_common_prefix(prev_args_json, args_json)
        return prefix[sent:]

0 commit comments

Comments
 (0)