Skip to content

Commit aac4953

Browse files
committed
Fix qa review
1 parent 6fa9dd1 commit aac4953

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

vllm/entrypoints/openai/tool_parsers/apertus_tool_parser.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,21 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
5757
self.tool_calls_suffix = "<|tools_suffix|>"
5858

5959
# State for streaming
60-
self.prev_tool_call_arr: list[dict] = []
61-
self.current_tool_id: int = -1
62-
self.current_tool_name_sent: bool = False
63-
self.streamed_args_for_tool: list[str] = []
60+
self._reset_streaming_state()
6461

6562
# Regex to extract tool calls block (suffix is optional for incomplete outputs)
6663
self.tool_call_regex = re.compile(
6764
rf"{re.escape(self.tool_calls_prefix)}(.*?)(?:{re.escape(self.tool_calls_suffix)}|$)",
6865
re.DOTALL,
6966
)
7067

68+
def _reset_streaming_state(self):
69+
"""Reset streaming state for a new request."""
70+
self.prev_tool_call_arr: list[dict] = []
71+
self.current_tool_id: int = -1
72+
self.current_tool_name_sent: bool = False
73+
self.streamed_args_for_tool: list[str] = []
74+
7175
def extract_tool_calls(
7276
self, model_output: str, request: ChatCompletionRequest
7377
) -> ExtractedToolCallInformation:
@@ -121,6 +125,7 @@ def _parse_tool_call_objects(self, tool_call_objects: list[dict]) -> list[ToolCa
121125
name=function_name,
122126
arguments=json.dumps(arguments, ensure_ascii=False),
123127
),
128+
id=make_tool_call_id(),
124129
)
125130
)
126131

@@ -137,6 +142,12 @@ def extract_tool_calls_streaming(
137142
request: ChatCompletionRequest,
138143
) -> DeltaMessage | None:
139144
"""Extract tool calls in streaming mode."""
145+
# Reset state at the start of a new streaming session
146+
# (detected when previous_text is empty or doesn't contain tool prefix)
147+
if not previous_text or (self.tool_calls_prefix not in previous_text and
148+
self.tool_calls_prefix in current_text):
149+
self._reset_streaming_state()
150+
140151
# Check if we're in a tool call block
141152
if self.tool_calls_prefix not in current_text:
142153
return DeltaMessage(content=delta_text)
@@ -153,13 +164,16 @@ def extract_tool_calls_streaming(
153164
if len(tool_call_arr) > self.current_tool_id + 1:
154165
delta = self._finalize_previous_tool()
155166
self._start_new_tool(len(tool_call_arr))
167+
self.prev_tool_call_arr = tool_call_arr
156168
return delta
157169

158170
current_tool_call = tool_call_arr[self.current_tool_id]
159171

160172
# Send tool name if not sent yet
161173
if not self.current_tool_name_sent:
162-
return self._send_tool_name(current_tool_call)
174+
delta = self._send_tool_name(current_tool_call)
175+
self.prev_tool_call_arr = tool_call_arr
176+
return delta
163177

164178
# Stream arguments
165179
delta = self._stream_arguments(current_tool_call, json_str)
@@ -196,6 +210,14 @@ def _finalize_previous_tool(self) -> DeltaMessage | None:
196210
if self.current_tool_id < 0:
197211
return None
198212

213+
# Check if prev_tool_call_arr has been initialized and has the current tool
214+
if not self.prev_tool_call_arr or self.current_tool_id >= len(self.prev_tool_call_arr):
215+
return None
216+
217+
# Check if streamed_args_for_tool has the current tool
218+
if self.current_tool_id >= len(self.streamed_args_for_tool):
219+
return None
220+
199221
prev_tool = self.prev_tool_call_arr[self.current_tool_id]
200222
function_name = next(iter(prev_tool))
201223
arguments = prev_tool[function_name]
@@ -258,6 +280,10 @@ def _stream_arguments(
258280
if not arguments:
259281
return None
260282

283+
# Check if streamed_args_for_tool has the current tool
284+
if self.current_tool_id >= len(self.streamed_args_for_tool):
285+
return None
286+
261287
sent = len(self.streamed_args_for_tool[self.current_tool_id])
262288
args_json = json.dumps(arguments, ensure_ascii=False)
263289

0 commit comments

Comments
 (0)