diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 12dd474936db..c8485dc00d80 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -6,7 +6,7 @@ import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence -from typing import Callable, Final, Optional, Union +from typing import Any, Callable, Final, Optional, Union import jinja2 import partial_json_parser @@ -748,12 +748,20 @@ async def chat_completion_stream_generator( if self.use_harmony: harmony_parser = harmony_parsers[i] prev_recipient = harmony_parser.current_recipient - delta_text = "" + + # Track accumulated content per token with their state + token_states = [] for token_id in output.token_ids: harmony_parser.process(token_id) - delta_text += harmony_parser.last_content_delta or "" - cur_channel = harmony_parser.current_channel - cur_recipient = harmony_parser.current_recipient + token_delta = harmony_parser.last_content_delta or "" + token_states.append( + ( + harmony_parser.current_channel, + harmony_parser.current_recipient, + token_delta, + ) + ) + delta_text = "".join(delta for _, _, delta in token_states) else: delta_text = output.text @@ -783,34 +791,74 @@ async def chat_completion_stream_generator( current_token_ids = as_list(output.token_ids) if self.use_harmony: - if cur_channel == "final": - delta_message = DeltaMessage(content=delta_text) - elif cur_channel == "analysis": - if request.include_reasoning: - delta_message = DeltaMessage( - reasoning_content=delta_text - ) + # Group consecutive tokens with same channel/recipient + groups: list[dict[str, str]] = [] + for channel, recipient, text in token_states: + if ( + groups + and groups[-1]["channel"] == channel + and groups[-1]["recipient"] == recipient + ): + groups[-1]["text"] += text else: - delta_message = None - elif ( - cur_channel == "commentary" - and cur_recipient - and cur_recipient.startswith("functions.") - ): - # Count completed tool calls to determine index - base_index = 0 - for msg in harmony_parser.messages: - if ( - msg.channel == "commentary" - and msg.recipient - and msg.recipient.startswith("functions.") - ): - base_index += 1 - - if prev_recipient != cur_recipient: - tool_name = cur_recipient.split("functions.", 1)[1] - delta_message = DeltaMessage( - tool_calls=[ + groups.append( + { + "channel": channel, + "recipient": recipient, + "text": text, + } + ) + + # Process each group and create delta messages + delta_message = None + combined_content = "" + combined_reasoning = "" + tool_messages = [] + + # Calculate base_index once before the loop + # This counts completed tool calls in messages + base_index = 0 + for msg in harmony_parser.messages: + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): + base_index += 1 + + # If there's an ongoing tool call from previous chunk, + # the next new tool call starts at base_index + 1 + if prev_recipient and prev_recipient.startswith("functions."): + next_tool_index = base_index + 1 + # Ongoing call is at base_index + ongoing_tool_index = base_index + else: + # No ongoing call, next new call is at base_index + next_tool_index = base_index + ongoing_tool_index = None + + for group in groups: + group_channel = group["channel"] + group_recipient = group["recipient"] + group_text = group["text"] + + if group_channel == "final": + combined_content += group_text + elif group_channel == "analysis": + if request.include_reasoning: + combined_reasoning += group_text + elif ( + group_channel == "commentary" + and group_recipient + and group_recipient.startswith("functions.") + ): + opened_new_call = False + if prev_recipient != group_recipient: + # New tool call - emit the opening message + tool_name = group_recipient.split("functions.", 1)[ + 1 + ] + tool_messages.append( DeltaToolCall( id=make_tool_call_id(), type="function", @@ -818,26 +866,50 @@ async def chat_completion_stream_generator( name=tool_name, arguments="", ), - index=base_index, + index=next_tool_index, ) - ] - ) - elif delta_text: - delta_message = DeltaMessage( - tool_calls=[ + ) + opened_new_call = True + prev_recipient = group_recipient + # Increment for subsequent new tool calls + next_tool_index += 1 + + if group_text: + # Stream arguments for the ongoing tool call + if opened_new_call: + # Just opened in this group + tool_call_index = next_tool_index - 1 + else: + # Continuing from previous chunk + # If ongoing_tool_index is None here, it means + # we're continuing a call but prev_recipient + # wasn't a function. Use base_index. + tool_call_index = ( + ongoing_tool_index + if ongoing_tool_index is not None + else base_index + ) + tool_messages.append( DeltaToolCall( - index=base_index, + index=tool_call_index, function=DeltaFunctionCall( - arguments=delta_text + arguments=group_text ), ) - ] - ) - else: - delta_message = None + ) - if delta_message is not None: + # Combine all non-empty fields into a single message + if combined_content or combined_reasoning or tool_messages: + delta_kwargs: dict[str, Any] = {} + if combined_content: + delta_kwargs["content"] = combined_content + if combined_reasoning: + delta_kwargs["reasoning_content"] = combined_reasoning + if tool_messages: + delta_kwargs["tool_calls"] = tool_messages harmony_tools_streamed[i] = True + + delta_message = DeltaMessage(**delta_kwargs) else: delta_message = None # handle streaming deltas for tools with named tool_choice @@ -1076,17 +1148,23 @@ async def chat_completion_stream_generator( # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: - delta_content = "" + delta_content_parts = [] if delta_message.content: - delta_content = delta_message.content - elif delta_message.tool_calls: - delta_content = "".join( + delta_content_parts.append(delta_message.content) + if delta_message.reasoning_content: + reasoning = delta_message.reasoning_content + delta_content_parts.append(f"[reasoning: {reasoning}]") + if delta_message.tool_calls: + tool_args = "".join( tc.function.arguments for tc in delta_message.tool_calls if tc.function and tc.function.arguments ) + if tool_args: + delta_content_parts.append(f"[tool_calls: {tool_args}]") - if delta_content: + if delta_content_parts: + delta_content = " ".join(delta_content_parts) self.request_logger.log_outputs( request_id=request_id, outputs=delta_content,