Skip to content
Open
178 changes: 128 additions & 50 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
from typing import Any, Callable, Final, Optional, Union

import jinja2
import partial_json_parser
Expand Down Expand Up @@ -748,12 +748,20 @@ async def chat_completion_stream_generator(
if self.use_harmony:
harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient
delta_text = ""

# Record, for every token, the channel and recipient reported by the
# parser after processing it, together with that token's content delta
token_states = []
for token_id in output.token_ids:
harmony_parser.process(token_id)
delta_text += harmony_parser.last_content_delta or ""
cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
token_delta = harmony_parser.last_content_delta or ""
token_states.append(
(
harmony_parser.current_channel,
harmony_parser.current_recipient,
token_delta,
)
)
delta_text = "".join(delta for _, _, delta in token_states)
else:
delta_text = output.text

Expand Down Expand Up @@ -783,61 +791,125 @@ async def chat_completion_stream_generator(
current_token_ids = as_list(output.token_ids)

if self.use_harmony:
if cur_channel == "final":
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if request.include_reasoning:
delta_message = DeltaMessage(
reasoning_content=delta_text
)
# Group consecutive tokens with same channel/recipient
groups: list[dict[str, str]] = []
for channel, recipient, text in token_states:
if (
groups
and groups[-1]["channel"] == channel
and groups[-1]["recipient"] == recipient
):
groups[-1]["text"] += text
else:
delta_message = None
elif (
cur_channel == "commentary"
and cur_recipient
and cur_recipient.startswith("functions.")
):
# Count completed tool calls to determine index
base_index = 0
for msg in harmony_parser.messages:
if (
msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1

if prev_recipient != cur_recipient:
tool_name = cur_recipient.split("functions.", 1)[1]
delta_message = DeltaMessage(
tool_calls=[
groups.append(
{
"channel": channel,
"recipient": recipient,
"text": text,
}
)

# Process each group and create delta messages
delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []

# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0
for msg in harmony_parser.messages:
if (
msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1

# If there's an ongoing tool call from previous chunk,
# the next new tool call starts at base_index + 1
if prev_recipient and prev_recipient.startswith("functions."):
next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None

for group in groups:
group_channel = group["channel"]
group_recipient = group["recipient"]
group_text = group["text"]

if group_channel == "final":
combined_content += group_text
elif group_channel == "analysis":
if request.include_reasoning:
combined_reasoning += group_text
elif (
group_channel == "commentary"
and group_recipient
and group_recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group_recipient:
# New tool call - emit the opening message
tool_name = group_recipient.split("functions.", 1)[
1
]
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_name,
arguments="",
),
index=base_index,
index=next_tool_index,
)
]
)
elif delta_text:
delta_message = DeltaMessage(
tool_calls=[
)
opened_new_call = True
prev_recipient = group_recipient
# Increment for subsequent new tool calls
next_tool_index += 1

if group_text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
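# Not opened by this group: the call continues from the
# previous chunk or from an earlier group in this chunk.
# ongoing_tool_index is None when the chunk did not begin
# inside a tool call (prev_recipient was not "functions.*"
# at that point); base_index is used in that case.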
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall(
index=base_index,
index=tool_call_index,
function=DeltaFunctionCall(
arguments=delta_text
arguments=group_text
),
)
]
)
else:
delta_message = None
)

if delta_message is not None:
# Combine all non-empty fields into a single message
if combined_content or combined_reasoning or tool_messages:
delta_kwargs: dict[str, Any] = {}
if combined_content:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning_content"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
harmony_tools_streamed[i] = True

delta_message = DeltaMessage(**delta_kwargs)
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice
Expand Down Expand Up @@ -1076,17 +1148,23 @@ async def chat_completion_stream_generator(

# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
delta_content = ""
delta_content_parts = []
if delta_message.content:
delta_content = delta_message.content
elif delta_message.tool_calls:
delta_content = "".join(
delta_content_parts.append(delta_message.content)
if delta_message.reasoning_content:
reasoning = delta_message.reasoning_content
delta_content_parts.append(f"[reasoning: {reasoning}]")
if delta_message.tool_calls:
tool_args = "".join(
tc.function.arguments
for tc in delta_message.tool_calls
if tc.function and tc.function.arguments
)
if tool_args:
delta_content_parts.append(f"[tool_calls: {tool_args}]")

if delta_content:
if delta_content_parts:
delta_content = " ".join(delta_content_parts)
self.request_logger.log_outputs(
request_id=request_id,
outputs=delta_content,
Expand Down