1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import re
8
+ from collections .abc import Sequence
9
+ from typing import TYPE_CHECKING
10
+
11
+ from partial_json_parser .core .options import Allow
12
+ from transformers import PreTrainedTokenizerBase
13
+
14
+ from vllm .entrypoints .chat_utils import make_tool_call_id
15
+ from vllm .entrypoints .openai .protocol import (
16
+ ChatCompletionRequest ,
17
+ DeltaFunctionCall ,
18
+ DeltaMessage ,
19
+ DeltaToolCall ,
20
+ ExtractedToolCallInformation ,
21
+ FunctionCall ,
22
+ ToolCall ,
23
+ )
24
+ from vllm .entrypoints .openai .tool_parsers .abstract_tool_parser import (
25
+ ToolParser ,
26
+ ToolParserManager ,
27
+ )
28
+ from vllm .entrypoints .openai .tool_parsers .utils import (
29
+ find_common_prefix ,
30
+ is_complete_json ,
31
+ partial_json_loads ,
32
+ )
33
+ from vllm .logger import init_logger
34
+
35
if TYPE_CHECKING:
    # Placeholder: no typing-only imports are needed at the moment, so the
    # guarded block is intentionally empty.
    pass

# Module-level logger; the parser below uses it to report extraction and
# streaming parse failures.
logger = init_logger(__name__)
39
+
40
+
41
@ToolParserManager.register_module("apertus")
class ApertusToolParser(ToolParser):
    """
    Tool call parser for Apertus models.

    Extracts tool calls from the format:
    <|tools_prefix|>[{"function_name": {"arg1": "value1", ...}}, ...]<|tools_suffix|>

    Each array element is a single-key object whose key is the function name
    and whose value is the arguments payload. Both a one-shot path
    (``extract_tool_calls``) and an incremental path
    (``extract_tool_calls_streaming``) are provided; the latter keeps
    per-instance state between calls.

    Used when --enable-auto-tool-choice --tool-call-parser apertus are set.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Set up the delimiters, streaming state and extraction regex."""
        super().__init__(tokenizer)

        # Tokens for tool call delimiters
        self.tool_calls_prefix = "<|tools_prefix|>"
        self.tool_calls_suffix = "<|tools_suffix|>"

        # State for streaming
        # NOTE(review): these attributes are presumably also read by the
        # serving layer after streaming ends - confirm the entry shape it
        # expects for prev_tool_call_arr ({"name": {...args...}} here).
        self.prev_tool_call_arr: list[dict] = []  # last parsed tool array
        self.current_tool_id: int = -1  # index of the tool being streamed
        self.current_tool_name_sent: bool = False
        # Argument JSON already emitted, one string per tool index.
        self.streamed_args_for_tool: list[str] = []

        # Regex to extract the tool calls block. The suffix is optional so
        # that truncated outputs (no closing token) still match up to the
        # end of the text.
        self.tool_call_regex = re.compile(
            rf"{re.escape(self.tool_calls_prefix)}(.*?)"
            rf"(?:{re.escape(self.tool_calls_suffix)}|$)",
            re.DOTALL,
        )

    @staticmethod
    def _no_tool_calls(model_output: str) -> ExtractedToolCallInformation:
        """Build the "no tool call found" result, passing the text through."""
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model response.

        Args:
            model_output: Full text generated by the model.
            request: The originating request (not used by this parser).

        Returns:
            ``tools_called=True`` with the parsed calls when at least one
            well-formed call is found; any non-blank text preceding the
            prefix token is returned as ``content``. Otherwise
            ``tools_called=False`` with ``model_output`` as ``content``.
        """
        # Quick check before running regex
        if self.tool_calls_prefix not in model_output:
            return self._no_tool_calls(model_output)

        match = self.tool_call_regex.search(model_output)
        if not match:
            return self._no_tool_calls(model_output)

        try:
            tool_call_objects = json.loads(match.group(1).strip())

            # Tolerate a single object emitted without the enclosing array.
            if not isinstance(tool_call_objects, list):
                tool_call_objects = [tool_call_objects]

            tool_calls = self._parse_tool_call_objects(tool_call_objects)
        except Exception:
            # Best-effort boundary: malformed output falls back to plain text.
            logger.exception("Error extracting tool call from response.")
            return self._no_tool_calls(model_output)

        # A payload that parsed but held no well-formed call is not a tool
        # call; do not report tools_called=True with an empty list.
        if not tool_calls:
            return self._no_tool_calls(model_output)

        # Keep assistant text that precedes the tool-call block.
        leading_text = model_output[: match.start()]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=leading_text if leading_text.strip() else None,
        )

    def _parse_tool_call_objects(self, tool_call_objects: list[dict]) -> list[ToolCall]:
        """Parse tool call objects into ToolCall instances.

        Entries that are not single-key dicts are skipped (and logged).
        """
        tool_calls: list[ToolCall] = []

        for obj in tool_call_objects:
            # Each object is {"function_name": {"arg1": "value1", ...}}
            if not isinstance(obj, dict) or len(obj) != 1:
                logger.warning("Skipping malformed tool call entry: %r", obj)
                continue

            function_name, arguments = next(iter(obj.items()))
            tool_calls.append(
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(arguments, ensure_ascii=False),
                    ),
                )
            )

        return tool_calls

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls in streaming mode.

        Only ``current_text`` and ``delta_text`` are used. Returns a content
        delta while no prefix token has been seen, a tool-call delta when
        there is something new to emit, and ``None`` when more tokens are
        needed (or the partial JSON cannot be parsed yet).
        """
        # Not in a tool call block yet: pass the text through as content.
        if self.tool_calls_prefix not in current_text:
            return DeltaMessage(content=delta_text)

        json_str = self._extract_json_str(current_text)

        try:
            tool_call_arr = self._parse_partial_json(json_str)

            if not tool_call_arr:
                return None

            # Starting a new tool in the array
            if len(tool_call_arr) > self.current_tool_id + 1:
                # Flush the previous tool from the *current* parse: it holds
                # that tool's final arguments, whereas prev_tool_call_arr can
                # be one delta behind (or not contain the tool at all).
                delta = self._finalize_previous_tool(tool_call_arr)
                self._start_new_tool(len(tool_call_arr))
                return delta

            current_tool_call = tool_call_arr[self.current_tool_id]

            # Send tool name if not sent yet
            if not self.current_tool_name_sent:
                return self._send_tool_name(current_tool_call)

            # Stream arguments
            delta = self._stream_arguments(current_tool_call, json_str)
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # Partial JSON is routinely unparseable mid-stream; keep this
            # best-effort but retain the traceback for debugging.
            logger.debug(
                "Error parsing streaming tool call, waiting for more tokens",
                exc_info=True,
            )
            return None

    def _extract_json_str(self, current_text: str) -> str:
        """Extract JSON string from the current text.

        Returns the stripped text between the prefix token and the suffix
        token, or up to the end of the text when no suffix is present yet.
        """
        prefix_idx = current_text.find(self.tool_calls_prefix)
        start_idx = prefix_idx + len(self.tool_calls_prefix)

        # Check if suffix is present (complete tool call)
        suffix_idx = current_text.find(self.tool_calls_suffix, start_idx)
        if suffix_idx != -1:
            return current_text[start_idx:suffix_idx].strip()
        return current_text[start_idx:].strip()

    def _parse_partial_json(self, json_str: str) -> list[dict]:
        """Parse partial JSON with appropriate flags.

        Partial strings are disallowed until the current tool's name has
        been sent. A non-list result is wrapped in a list (or becomes an
        empty list when falsy).
        """
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        tool_call_arr, _ = partial_json_loads(json_str, flags)

        if not isinstance(tool_call_arr, list):
            tool_call_arr = [tool_call_arr] if tool_call_arr else []

        return tool_call_arr

    def _finalize_previous_tool(
        self, tool_call_arr: list[dict] | None = None
    ) -> DeltaMessage | None:
        """Emit whatever is still unsent for the tool being closed.

        Args:
            tool_call_arr: Freshly parsed tool array to read the finished
                tool from. Falls back to ``prev_tool_call_arr`` when omitted.

        Returns:
            A delta with the remaining argument text (plus id and name if the
            name was never sent), or ``None`` when there is nothing to emit.
            Never raises on missing or malformed entries, so the caller can
            always move on to the next tool.
        """
        tool_id = self.current_tool_id
        if tool_id < 0:
            return None

        source = self.prev_tool_call_arr if tool_call_arr is None else tool_call_arr
        if tool_id >= len(source) or tool_id >= len(self.streamed_args_for_tool):
            return None

        finished_tool = source[tool_id]
        if not isinstance(finished_tool, dict) or len(finished_tool) != 1:
            return None

        function_name, arguments = next(iter(finished_tool.items()))
        args_json = json.dumps(arguments, ensure_ascii=False)
        sent = len(self.streamed_args_for_tool[tool_id])
        argument_diff = args_json[sent:]

        if not argument_diff:
            return None

        self.streamed_args_for_tool[tool_id] += argument_diff

        if self.current_tool_name_sent:
            tool_call = DeltaToolCall(
                index=tool_id,
                function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                    exclude_none=True
                ),
            )
        else:
            # The whole call arrived within one delta: the name has not been
            # announced yet, so send it together with the arguments.
            self.current_tool_name_sent = True
            tool_call = DeltaToolCall(
                index=tool_id,
                type="function",
                id=make_tool_call_id(),
                function=DeltaFunctionCall(
                    name=function_name, arguments=argument_diff
                ).model_dump(exclude_none=True),
            )

        return DeltaMessage(tool_calls=[tool_call])

    def _start_new_tool(self, array_length: int) -> None:
        """Start processing a new tool (the last one in the parsed array)."""
        self.current_tool_id = array_length - 1
        self.current_tool_name_sent = False
        # Keep one slot per tool index. Normally this appends exactly one
        # entry; padding also covers the array growing by several elements
        # within a single delta, which would otherwise break later indexing.
        missing = array_length - len(self.streamed_args_for_tool)
        if missing > 0:
            self.streamed_args_for_tool.extend([""] * missing)

    def _send_tool_name(self, current_tool_call: dict) -> DeltaMessage | None:
        """Send the tool name if not sent yet.

        Returns ``None`` until the entry is a single-key dict, i.e. until the
        function name key is fully available.
        """
        if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
            return None

        function_name = next(iter(current_tool_call))

        self.current_tool_name_sent = True
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                    id=make_tool_call_id(),
                    function=DeltaFunctionCall(name=function_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _stream_arguments(
        self, current_tool_call: dict, json_str: str
    ) -> DeltaMessage | None:
        """Stream arguments for the current tool.

        Args:
            current_tool_call: Parsed entry for the tool being streamed.
            json_str: Raw (possibly partial) JSON text of the whole array.

        Returns:
            A delta carrying the newly available argument text, or ``None``
            when there is nothing new to send.
        """
        if not isinstance(current_tool_call, dict) or len(current_tool_call) != 1:
            return None

        function_name, arguments = next(iter(current_tool_call.items()))

        # Empty arguments are only worth sending once the JSON is complete
        # (e.g. a call with "{}"); before that they are just not there yet.
        if not arguments and not is_complete_json(json_str):
            return None

        sent = len(self.streamed_args_for_tool[self.current_tool_id])
        args_json = json.dumps(arguments, ensure_ascii=False)

        argument_diff = self._calculate_argument_diff(
            function_name, args_json, json_str, sent
        )

        if not argument_diff:
            return None

        self.streamed_args_for_tool[self.current_tool_id] += argument_diff
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _calculate_argument_diff(
        self, function_name: str, args_json: str, json_str: str, sent: int
    ) -> str | None:
        """Calculate the difference in arguments to stream.

        Args:
            function_name: Name of the tool currently being streamed.
            args_json: Serialized arguments from the current partial parse.
            json_str: Raw JSON text of the whole array.
            sent: Number of argument characters already emitted.

        Returns:
            The not-yet-sent argument text, or ``None`` when nothing can be
            safely emitted yet.
        """
        # Once the whole payload is valid JSON nothing can change any more:
        # flush everything that has not been sent.
        if is_complete_json(json_str):
            return args_json[sent:]

        if not self.prev_tool_call_arr or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            return None

        prev_tool = self.prev_tool_call_arr[self.current_tool_id]
        if not isinstance(prev_tool, dict) or len(prev_tool) != 1:
            return None

        prev_function_name, prev_args = next(iter(prev_tool.items()))
        if prev_function_name != function_name:
            return None

        prev_args_json = json.dumps(prev_args, ensure_ascii=False)
        if args_json == prev_args_json:
            return None

        # Only the part shared by two consecutive partial parses is stable;
        # the tail may still be auto-completed closing quotes or brackets.
        prefix = find_common_prefix(prev_args_json, args_json)
        return prefix[sent:]
0 commit comments