2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -25,6 +25,7 @@
from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
from .apertus_tool_parser import ApertusToolParser

__all__ = [
"ToolParser",
@@ -52,4 +53,5 @@
"SeedOssToolParser",
"Step3ToolParser",
"OpenAIToolParser",
"ApertusToolParser",
]
336 changes: 336 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/apertus_tool_parser.py
@@ -0,0 +1,336 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import json
import re
from collections.abc import Sequence
from typing import TYPE_CHECKING

from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.entrypoints.openai.tool_parsers.utils import (
find_common_prefix,
is_complete_json,
partial_json_loads,
)
from vllm.logger import init_logger

if TYPE_CHECKING:
pass

logger = init_logger(__name__)


@ToolParserManager.register_module("apertus")
class ApertusToolParser(ToolParser):
"""
Tool call parser for Apertus models.

Extracts tool calls from the format:
<|tools_prefix|>[{"function_name": {"arg1": "value1", ...}}, ...]<|tools_suffix|>

Used when --enable-auto-tool-choice --tool-call-parser apertus are set.
"""

def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)

# Tokens for tool call delimiters
self.tool_calls_prefix = "<|tools_prefix|>"
self.tool_calls_suffix = "<|tools_suffix|>"

# State for streaming
self._reset_streaming_state()

# Regex to extract tool calls block (suffix is optional for incomplete outputs)
self.tool_call_regex = re.compile(
rf"{re.escape(self.tool_calls_prefix)}(.*?)(?:{re.escape(self.tool_calls_suffix)}|$)",
re.DOTALL,
)

def _reset_streaming_state(self):
"""Reset streaming state for a new request."""
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = []

def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""Extract tool calls from a complete model response."""
# Quick check before running regex
if self.tool_calls_prefix not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)

# Find tool calls block
match = self.tool_call_regex.search(model_output)
if not match:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)

try:
json_str = match.group(1).strip()
tool_call_objects = json.loads(json_str)

if not isinstance(tool_call_objects, list):
tool_call_objects = [tool_call_objects]

tool_calls = self._parse_tool_call_objects(tool_call_objects)

return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)

except Exception:
logger.exception("Error extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)

def _parse_tool_call_objects(self, tool_call_objects: list[dict]) -> list[ToolCall]:
"""Parse tool call objects into ToolCall instances."""
tool_calls: list[ToolCall] = []

for obj in tool_call_objects:
# Each object is {"function_name": {"arg1": "value1", ...}}
if isinstance(obj, dict) and len(obj) == 1:
function_name = next(iter(obj))
arguments = obj[function_name]

tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=function_name,
                            arguments=json.dumps(arguments, ensure_ascii=False),
),
id=make_tool_call_id(),
)
)

return tool_calls

def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls in streaming mode."""
        # Reset state at the start of a new streaming session (detected when
        # previous_text is empty, or when the tool-call prefix first appears
        # in current_text)
if not previous_text or (self.tool_calls_prefix not in previous_text and
self.tool_calls_prefix in current_text):
self._reset_streaming_state()

# Check if we're in a tool call block
if self.tool_calls_prefix not in current_text:
return DeltaMessage(content=delta_text)

json_str = self._extract_json_str(current_text)

try:
tool_call_arr = self._parse_partial_json(json_str)

if not tool_call_arr:
return None

# Starting a new tool in the array
if len(tool_call_arr) > self.current_tool_id + 1:
delta = self._finalize_previous_tool()
self._start_new_tool(len(tool_call_arr))
self.prev_tool_call_arr = tool_call_arr
return delta

current_tool_call = tool_call_arr[self.current_tool_id]

# Send tool name if not sent yet
if not self.current_tool_name_sent:
delta = self._send_tool_name(current_tool_call)
self.prev_tool_call_arr = tool_call_arr
return delta

# Stream arguments
delta = self._stream_arguments(current_tool_call, json_str)
self.prev_tool_call_arr = tool_call_arr
return delta

except Exception:
logger.debug("Error parsing streaming tool call, waiting for more tokens")
return None

def _extract_json_str(self, current_text: str) -> str:
"""Extract JSON string from the current text."""
prefix_idx = current_text.find(self.tool_calls_prefix)
start_idx = prefix_idx + len(self.tool_calls_prefix)

# Check if suffix is present (complete tool call)
suffix_idx = current_text.find(self.tool_calls_suffix, start_idx)
if suffix_idx != -1:
return current_text[start_idx:suffix_idx].strip()
return current_text[start_idx:].strip()

def _parse_partial_json(self, json_str: str) -> list[dict]:
"""Parse partial JSON with appropriate flags."""
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
tool_call_arr, _ = partial_json_loads(json_str, flags)

if not isinstance(tool_call_arr, list):
tool_call_arr = [tool_call_arr] if tool_call_arr else []

return tool_call_arr

def _finalize_previous_tool(self) -> DeltaMessage | None:
"""Finalize any remaining arguments from the previous tool."""
if self.current_tool_id < 0:
return None

# Check if prev_tool_call_arr has been initialized and has the current tool
        if not self.prev_tool_call_arr or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
return None

# Check if streamed_args_for_tool has the current tool
if self.current_tool_id >= len(self.streamed_args_for_tool):
return None

prev_tool = self.prev_tool_call_arr[self.current_tool_id]
function_name = next(iter(prev_tool))
arguments = prev_tool[function_name]
args_json = json.dumps(arguments, ensure_ascii=False)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = args_json[sent:]

if not argument_diff:
return None

self.streamed_args_for_tool[self.current_tool_id] += argument_diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
exclude_none=True
),
)
]
)

def _start_new_tool(self, array_length: int) -> None:
"""Start processing a new tool."""
self.current_tool_id = array_length - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")

def _send_tool_name(self, current_tool_call: dict) -> DeltaMessage | None:
"""Send the tool name if not sent yet."""
if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
return None

function_name = next(iter(current_tool_call))

self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(name=function_name).model_dump(
exclude_none=True
),
)
]
)

def _stream_arguments(
self, current_tool_call: dict, json_str: str
) -> DeltaMessage | None:
"""Stream arguments for the current tool."""
if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
return None

function_name = next(iter(current_tool_call))
arguments = current_tool_call[function_name]

if not arguments:
return None

# Check if streamed_args_for_tool has the current tool
if self.current_tool_id >= len(self.streamed_args_for_tool):
return None

sent = len(self.streamed_args_for_tool[self.current_tool_id])
args_json = json.dumps(arguments, ensure_ascii=False)

argument_diff = self._calculate_argument_diff(
function_name, args_json, json_str, sent
)

if not argument_diff:
return None

self.streamed_args_for_tool[self.current_tool_id] += argument_diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
exclude_none=True
),
)
]
)

def _calculate_argument_diff(
self, function_name: str, args_json: str, json_str: str, sent: int
) -> str | None:
"""Calculate the difference in arguments to stream."""
is_complete_call = is_complete_json(json_str)

if is_complete_call:
return args_json[sent:]

if not self.prev_tool_call_arr or self.current_tool_id >= len(
self.prev_tool_call_arr
):
return None

prev_tool = self.prev_tool_call_arr[self.current_tool_id]
prev_function_name = next(iter(prev_tool))

if prev_function_name != function_name:
return None

prev_args = prev_tool[prev_function_name]
prev_args_json = json.dumps(prev_args, ensure_ascii=False)

if args_json == prev_args_json:
return None

prefix = find_common_prefix(prev_args_json, args_json)
return prefix[sent:]