Skip to content

Commit 3ad1d7b

Browse files
committed
Increment tool index when new call follows ongoing call
Signed-off-by: Aleksandr Samarin <[email protected]>
1 parent c2dee6f commit 3ad1d7b

File tree

1 file changed

+87
-52
lines changed

1 file changed

+87
-52
lines changed

vllm/entrypoints/openai/serving_chat.py

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,14 @@ async def chat_completion_stream_generator(
754754
for token_id in output.token_ids:
755755
harmony_parser.process(token_id)
756756
token_delta = harmony_parser.last_content_delta or ""
757-
token_states.append((
758-
harmony_parser.current_channel,
759-
harmony_parser.current_recipient,
760-
token_delta
761-
))
762-
delta_text = "".join(state for _, _, state in token_states)
757+
token_states.append(
758+
(
759+
harmony_parser.current_channel,
760+
harmony_parser.current_recipient,
761+
token_delta,
762+
)
763+
)
764+
delta_text = "".join(delta for _, _, delta in token_states)
763765
else:
764766
delta_text = output.text
765767

@@ -792,15 +794,20 @@ async def chat_completion_stream_generator(
792794
# Group consecutive tokens with same channel/recipient
793795
groups: list[dict[str, str]] = []
794796
for channel, recipient, text in token_states:
795-
if (groups and groups[-1]['channel'] == channel
796-
and groups[-1]['recipient'] == recipient):
797-
groups[-1]['text'] += text
797+
if (
798+
groups
799+
and groups[-1]["channel"] == channel
800+
and groups[-1]["recipient"] == recipient
801+
):
802+
groups[-1]["text"] += text
798803
else:
799-
groups.append({
800-
'channel': channel,
801-
'recipient': recipient,
802-
'text': text
803-
})
804+
groups.append(
805+
{
806+
"channel": channel,
807+
"recipient": recipient,
808+
"text": text,
809+
}
810+
)
804811

805812
# Process each group and create delta messages
806813
delta_message = None
@@ -809,70 +816,97 @@ async def chat_completion_stream_generator(
809816
tool_messages = []
810817

811818
# Calculate base_index once before the loop
812-
# This represents the number of completed tool calls
819+
# This counts completed tool calls in messages
813820
base_index = 0
814821
for msg in harmony_parser.messages:
815-
if (msg.channel == "commentary"
816-
and msg.recipient
817-
and msg.recipient.startswith(
818-
"functions.")):
822+
if (
823+
msg.channel == "commentary"
824+
and msg.recipient
825+
and msg.recipient.startswith("functions.")
826+
):
819827
base_index += 1
820828

821-
# next_tool_index tracks the index for the next NEW tool call
822-
next_tool_index = base_index
829+
# If there's an ongoing tool call from previous chunk,
830+
# the next new tool call starts at base_index + 1
831+
if prev_recipient and prev_recipient.startswith("functions."):
832+
next_tool_index = base_index + 1
833+
# Ongoing call is at base_index
834+
ongoing_tool_index = base_index
835+
else:
836+
# No ongoing call, next new call is at base_index
837+
next_tool_index = base_index
838+
ongoing_tool_index = None
823839

824840
for group in groups:
825-
group_channel = group['channel']
826-
group_recipient = group['recipient']
827-
group_text = group['text']
841+
group_channel = group["channel"]
842+
group_recipient = group["recipient"]
843+
group_text = group["text"]
828844

829845
if group_channel == "final":
830846
combined_content += group_text
831847
elif group_channel == "analysis":
832848
if request.include_reasoning:
833849
combined_reasoning += group_text
834-
elif (group_channel == "commentary" and group_recipient
835-
and group_recipient.startswith("functions.")):
836-
850+
elif (
851+
group_channel == "commentary"
852+
and group_recipient
853+
and group_recipient.startswith("functions.")
854+
):
855+
opened_new_call = False
837856
if prev_recipient != group_recipient:
838857
# New tool call - emit the opening message
839-
tool_name = group_recipient.split(
840-
"functions.", 1)[1]
841-
tool_messages.append(DeltaToolCall(
842-
id=make_tool_call_id(),
843-
type="function",
844-
function=DeltaFunctionCall(
845-
name=tool_name,
846-
arguments="",
847-
),
848-
index=next_tool_index,
849-
))
858+
tool_name = group_recipient.split("functions.", 1)[
859+
1
860+
]
861+
tool_messages.append(
862+
DeltaToolCall(
863+
id=make_tool_call_id(),
864+
type="function",
865+
function=DeltaFunctionCall(
866+
name=tool_name,
867+
arguments="",
868+
),
869+
index=next_tool_index,
870+
)
871+
)
872+
opened_new_call = True
850873
prev_recipient = group_recipient
851874
# Increment for subsequent new tool calls
852875
next_tool_index += 1
853876

854877
if group_text:
855878
# Stream arguments for the ongoing tool call
856-
# Use next_tool_index - 1 if we opened a call
857-
# this chunk, else base_index for ongoing
858-
tool_call_index = (next_tool_index - 1
859-
if next_tool_index > base_index
860-
else base_index)
861-
tool_messages.append(DeltaToolCall(
862-
index=tool_call_index,
863-
function=DeltaFunctionCall(
864-
arguments=group_text),
865-
))
879+
if opened_new_call:
880+
# Just opened in this group
881+
tool_call_index = next_tool_index - 1
882+
else:
883+
# Continuing from previous chunk
884+
# If ongoing_tool_index is None here, it means
885+
# we're continuing a call but prev_recipient
886+
# wasn't a function. Use base_index.
887+
tool_call_index = (
888+
ongoing_tool_index
889+
if ongoing_tool_index is not None
890+
else base_index
891+
)
892+
tool_messages.append(
893+
DeltaToolCall(
894+
index=tool_call_index,
895+
function=DeltaFunctionCall(
896+
arguments=group_text
897+
),
898+
)
899+
)
866900

867901
# Combine all non-empty fields into a single message
868902
if combined_content or combined_reasoning or tool_messages:
869903
delta_kwargs: dict[str, Any] = {}
870904
if combined_content:
871-
delta_kwargs['content'] = combined_content
905+
delta_kwargs["content"] = combined_content
872906
if combined_reasoning:
873-
delta_kwargs['reasoning_content'] = combined_reasoning
907+
delta_kwargs["reasoning_content"] = combined_reasoning
874908
if tool_messages:
875-
delta_kwargs['tool_calls'] = tool_messages
909+
delta_kwargs["tool_calls"] = tool_messages
876910
harmony_tools_streamed[i] = True
877911

878912
delta_message = DeltaMessage(**delta_kwargs)
@@ -1124,7 +1158,8 @@ async def chat_completion_stream_generator(
11241158
tool_args = "".join(
11251159
tc.function.arguments
11261160
for tc in delta_message.tool_calls
1127-
if tc.function and tc.function.arguments)
1161+
if tc.function and tc.function.arguments
1162+
)
11281163
if tool_args:
11291164
delta_content_parts.append(f"[tool_calls: {tool_args}]")
11301165

0 commit comments

Comments
 (0)