Skip to content

Commit bdeac74

Browse files
committed
Handle multiple channels in one decoding stage
1 parent f9e7148 commit bdeac74

File tree

1 file changed

+83
-43
lines changed

1 file changed

+83
-43
lines changed

vllm/entrypoints/openai/serving_chat.py

Lines changed: 83 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -677,11 +677,18 @@ async def chat_completion_stream_generator(
677677
if self.use_harmony:
678678
harmony_parser = harmony_parsers[i]
679679
prev_recipient = harmony_parser.current_recipient
680+
681+
# Track accumulated content per token with their state
682+
token_states = []
680683
for token_id in output.token_ids:
681684
harmony_parser.process(token_id)
682-
cur_channel = harmony_parser.current_channel
683-
cur_recipient = harmony_parser.current_recipient
684-
delta_text = harmony_parser.last_content_delta or ""
685+
token_delta = harmony_parser.last_content_delta or ""
686+
token_states.append((
687+
harmony_parser.current_channel,
688+
harmony_parser.current_recipient,
689+
token_delta
690+
))
691+
delta_text = "".join(state[2] for state in token_states)
685692
else:
686693
delta_text = output.text
687694

@@ -707,52 +714,80 @@ async def chat_completion_stream_generator(
707714
current_token_ids = as_list(output.token_ids)
708715

709716
if self.use_harmony:
710-
if cur_channel == "final":
711-
delta_message = DeltaMessage(content=delta_text)
712-
elif cur_channel == "analysis":
713-
if request.include_reasoning:
714-
delta_message = DeltaMessage(
715-
reasoning_content=delta_text)
717+
# Group consecutive tokens with same channel/recipient
718+
groups = []
719+
for channel, recipient, text in token_states:
720+
if not text:
721+
continue
722+
if groups and groups[-1]['channel'] == channel and groups[-1]['recipient'] == recipient:
723+
groups[-1]['text'] += text
716724
else:
717-
delta_message = None
718-
elif (cur_channel == "commentary" and cur_recipient
719-
and cur_recipient.startswith("functions.")):
720-
# Count completed tool calls to determine index
721-
base_index = 0
722-
for msg in harmony_parser.messages:
723-
if (msg.channel == "commentary"
724-
and msg.recipient
725-
and msg.recipient.startswith(
726-
"functions.")):
727-
base_index += 1
728-
729-
if prev_recipient != cur_recipient:
730-
tool_name = cur_recipient.split(
731-
"functions.", 1)[1]
732-
delta_message = DeltaMessage(tool_calls=[
733-
DeltaToolCall(
725+
groups.append({
726+
'channel': channel,
727+
'recipient': recipient,
728+
'text': text
729+
})
730+
731+
# Process each group and create delta messages
732+
delta_message = None
733+
combined_content = ""
734+
combined_reasoning = ""
735+
tool_messages = []
736+
737+
for group in groups:
738+
group_channel = group['channel']
739+
group_recipient = group['recipient']
740+
group_text = group['text']
741+
742+
if group_channel == "final":
743+
combined_content += group_text
744+
elif group_channel == "analysis":
745+
if request.include_reasoning:
746+
combined_reasoning += group_text
747+
elif (group_channel == "commentary" and group_recipient
748+
and group_recipient.startswith("functions.")):
749+
750+
base_index = 0
751+
for msg in harmony_parser.messages:
752+
if (msg.channel == "commentary"
753+
and msg.recipient
754+
and msg.recipient.startswith(
755+
"functions.")):
756+
base_index += 1
757+
758+
if prev_recipient != group_recipient:
759+
tool_name = group_recipient.split(
760+
"functions.", 1)[1]
761+
tool_messages.append(DeltaToolCall(
734762
id=make_tool_call_id(),
735763
type="function",
736764
function=DeltaFunctionCall(
737765
name=tool_name,
738766
arguments="",
739767
),
740768
index=base_index,
741-
)
742-
])
743-
elif delta_text:
744-
delta_message = DeltaMessage(tool_calls=[
745-
DeltaToolCall(
769+
))
770+
prev_recipient = group_recipient
771+
772+
if group_text:
773+
tool_messages.append(DeltaToolCall(
746774
index=base_index,
747775
function=DeltaFunctionCall(
748-
arguments=delta_text),
749-
)
750-
])
751-
else:
752-
delta_message = None
753-
754-
if delta_message is not None:
776+
arguments=group_text),
777+
))
778+
779+
# Combine all non-empty fields into a single message
780+
if combined_content or combined_reasoning or tool_messages:
781+
delta_kwargs = {}
782+
if combined_content:
783+
delta_kwargs['content'] = combined_content
784+
if combined_reasoning:
785+
delta_kwargs['reasoning_content'] = combined_reasoning
786+
if tool_messages:
787+
delta_kwargs['tool_calls'] = tool_messages
755788
harmony_tools_streamed[i] = True
789+
790+
delta_message = DeltaMessage(**delta_kwargs)
756791
else:
757792
delta_message = None
758793
# handle streaming deltas for tools with named tool_choice
@@ -971,16 +1006,21 @@ async def chat_completion_stream_generator(
9711006

9721007
# Log streaming delta if output logging is enabled
9731008
if self.enable_log_outputs and self.request_logger:
974-
delta_content = ""
1009+
delta_content_parts = []
9751010
if delta_message.content:
976-
delta_content = delta_message.content
977-
elif delta_message.tool_calls:
978-
delta_content = "".join(
1011+
delta_content_parts.append(delta_message.content)
1012+
if delta_message.reasoning_content:
1013+
delta_content_parts.append(f"[reasoning: {delta_message.reasoning_content}]")
1014+
if delta_message.tool_calls:
1015+
tool_args = "".join(
9791016
tc.function.arguments
9801017
for tc in delta_message.tool_calls
9811018
if tc.function and tc.function.arguments)
1019+
if tool_args:
1020+
delta_content_parts.append(f"[tool_calls: {tool_args}]")
9821021

983-
if delta_content:
1022+
if delta_content_parts:
1023+
delta_content = " ".join(delta_content_parts)
9841024
self.request_logger.log_outputs(
9851025
request_id=request_id,
9861026
outputs=delta_content,

0 commit comments

Comments
 (0)