@@ -57,17 +57,21 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
57
57
self .tool_calls_suffix = "<|tools_suffix|>"
58
58
59
59
# State for streaming
60
- self .prev_tool_call_arr : list [dict ] = []
61
- self .current_tool_id : int = - 1
62
- self .current_tool_name_sent : bool = False
63
- self .streamed_args_for_tool : list [str ] = []
60
+ self ._reset_streaming_state ()
64
61
65
62
# Regex to extract tool calls block (suffix is optional for incomplete outputs)
66
63
self .tool_call_regex = re .compile (
67
64
rf"{ re .escape (self .tool_calls_prefix )} (.*?)(?:{ re .escape (self .tool_calls_suffix )} |$)" ,
68
65
re .DOTALL ,
69
66
)
70
67
68
+ def _reset_streaming_state (self ):
69
+ """Reset streaming state for a new request."""
70
+ self .prev_tool_call_arr : list [dict ] = []
71
+ self .current_tool_id : int = - 1
72
+ self .current_tool_name_sent : bool = False
73
+ self .streamed_args_for_tool : list [str ] = []
74
+
71
75
def extract_tool_calls (
72
76
self , model_output : str , request : ChatCompletionRequest
73
77
) -> ExtractedToolCallInformation :
@@ -121,6 +125,7 @@ def _parse_tool_call_objects(self, tool_call_objects: list[dict]) -> list[ToolCa
121
125
name = function_name ,
122
126
arguments = json .dumps (arguments , ensure_ascii = False ),
123
127
),
128
+ id = make_tool_call_id (),
124
129
)
125
130
)
126
131
@@ -137,6 +142,12 @@ def extract_tool_calls_streaming(
137
142
request : ChatCompletionRequest ,
138
143
) -> DeltaMessage | None :
139
144
"""Extract tool calls in streaming mode."""
145
+ # Reset state at the start of a new streaming session
146
+ # (detected when previous_text is empty or doesn't contain tool prefix)
147
+ if not previous_text or (self .tool_calls_prefix not in previous_text and
148
+ self .tool_calls_prefix in current_text ):
149
+ self ._reset_streaming_state ()
150
+
140
151
# Check if we're in a tool call block
141
152
if self .tool_calls_prefix not in current_text :
142
153
return DeltaMessage (content = delta_text )
@@ -153,13 +164,16 @@ def extract_tool_calls_streaming(
153
164
if len (tool_call_arr ) > self .current_tool_id + 1 :
154
165
delta = self ._finalize_previous_tool ()
155
166
self ._start_new_tool (len (tool_call_arr ))
167
+ self .prev_tool_call_arr = tool_call_arr
156
168
return delta
157
169
158
170
current_tool_call = tool_call_arr [self .current_tool_id ]
159
171
160
172
# Send tool name if not sent yet
161
173
if not self .current_tool_name_sent :
162
- return self ._send_tool_name (current_tool_call )
174
+ delta = self ._send_tool_name (current_tool_call )
175
+ self .prev_tool_call_arr = tool_call_arr
176
+ return delta
163
177
164
178
# Stream arguments
165
179
delta = self ._stream_arguments (current_tool_call , json_str )
@@ -196,6 +210,14 @@ def _finalize_previous_tool(self) -> DeltaMessage | None:
196
210
if self .current_tool_id < 0 :
197
211
return None
198
212
213
+ # Check if prev_tool_call_arr has been initialized and has the current tool
214
+ if not self .prev_tool_call_arr or self .current_tool_id >= len (self .prev_tool_call_arr ):
215
+ return None
216
+
217
+ # Check if streamed_args_for_tool has the current tool
218
+ if self .current_tool_id >= len (self .streamed_args_for_tool ):
219
+ return None
220
+
199
221
prev_tool = self .prev_tool_call_arr [self .current_tool_id ]
200
222
function_name = next (iter (prev_tool ))
201
223
arguments = prev_tool [function_name ]
@@ -258,6 +280,10 @@ def _stream_arguments(
258
280
if not arguments :
259
281
return None
260
282
283
+ # Check if streamed_args_for_tool has the current tool
284
+ if self .current_tool_id >= len (self .streamed_args_for_tool ):
285
+ return None
286
+
261
287
sent = len (self .streamed_args_for_tool [self .current_tool_id ])
262
288
args_json = json .dumps (arguments , ensure_ascii = False )
263
289
0 commit comments