Skip to content

Commit 4a2a62c

Browse files
ahengljhclaude
and committed
[Bugfix] Merge prefill + decode embeddings in thinker2talker for PD disaggregation
In PD mode the decode engine's multimodal_output only covers tokens it computed (~9), but thinker_sequences has the full prompt + generated tokens (~20). This misalignment caused the talker to map embeddings to wrong token positions, producing garbled output. Now thinker2talker detects PD mode via a preceding prefill stage and merges prefill prompt embeddings with decode generated embeddings so the talker receives the complete, correctly-aligned sequence. TTS embeds also fall back to the prefill stage if missing from decode. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d8213d5 commit 4a2a62c

File tree

1 file changed

+105
-6
lines changed

1 file changed

+105
-6
lines changed

vllm_omni/model_executor/stage_input_processors/qwen3_omni.py

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright 2025 The Qwen team.
44
"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition."""
55

6+
import logging
67
from typing import Any
78

89
import torch
@@ -12,6 +13,8 @@
1213
from vllm_omni.engine import OmniEngineCoreRequest
1314
from vllm_omni.inputs.data import OmniTokensPrompt
1415

16+
logger = logging.getLogger(__name__)
17+
1518

1619
def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int:
1720
im_start_token_id = 151644
@@ -141,6 +144,62 @@ def thinker2talker_async_chunk(
141144
return talker_additional_info
142145

143146

147+
def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None:
148+
"""Return the preceding prefill stage if PD disaggregation is active."""
149+
if source_stage_id <= 0:
150+
return None
151+
source_stage = stage_list[source_stage_id]
152+
if not getattr(source_stage, "is_decode_only", False):
153+
return None
154+
prev_stage = stage_list[source_stage_id - 1]
155+
if (
156+
getattr(prev_stage, "is_prefill_only", False)
157+
and prev_stage.engine_outputs is not None
158+
):
159+
return prev_stage
160+
return None
161+
162+
163+
def _merge_pd_embeddings(
164+
decode_emb: torch.Tensor,
165+
decode_hid: torch.Tensor,
166+
prefill_mm: dict[str, Any],
167+
device: torch.device,
168+
) -> tuple[torch.Tensor, torch.Tensor]:
169+
"""Merge prefill prompt embeddings with decode generated embeddings.
170+
171+
In PD disaggregation the decode engine only produces embeddings for the
172+
tokens it actually computed (1 remaining-prompt + N generated). The
173+
prefill engine has embeddings for the full prompt. We concatenate them
174+
so the talker sees the complete sequence::
175+
176+
merged = prefill[0 : prompt_len] + decode[1:]
177+
^ ^
178+
prompt positions generated positions
179+
(skip overlap at last prompt pos)
180+
"""
181+
try:
182+
p_emb = prefill_mm["0"].detach().to(device=device, dtype=torch.float)
183+
p_hid = prefill_mm["24"].detach().to(device=device, dtype=torch.float)
184+
except (KeyError, AttributeError, TypeError):
185+
return decode_emb, decode_hid
186+
187+
if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0:
188+
return decode_emb, decode_hid
189+
190+
# decode[0] is the recomputed last-prompt-token (overlap with prefill[-1]).
191+
merged_emb = torch.cat([p_emb, decode_emb[1:]], dim=0)
192+
merged_hid = torch.cat([p_hid, decode_hid[1:]], dim=0)
193+
194+
logger.info(
195+
"[PD] Merged prefill(%d) + decode(%d) → %d embeddings",
196+
p_emb.shape[0],
197+
decode_emb.shape[0],
198+
merged_emb.shape[0],
199+
)
200+
return merged_emb, merged_hid
201+
202+
144203
def thinker2talker(
145204
stage_list: list[Any],
146205
engine_input_source: list[int],
@@ -155,6 +214,12 @@ def thinker2talker(
155214
2. Split hidden states into: prompt embeddings + generated embeddings
156215
3. Package for talker with additional information
157216
217+
In PD disaggregation the decode engine's multimodal_output only covers
218+
the tokens it computed (not the full prompt). When a preceding prefill
219+
stage is detected we merge the prefill's prompt embeddings with the
220+
decode's generated embeddings so the talker receives the complete
221+
sequence.
222+
158223
Args:
159224
stage_list: List of stage objects
160225
engine_input_source: Source stage IDs (typically [0] for thinker)
@@ -169,21 +234,55 @@ def thinker2talker(
169234

170235
device = torch.device(current_platform.device_type)
171236

237+
# PD disaggregation: look for a preceding prefill stage whose
238+
# embeddings we need to merge with the decode output.
239+
source_stage_id = engine_input_source[0]
240+
prefill_stage = _get_prefill_stage(stage_list, source_stage_id)
241+
172242
# Process each thinker output
173-
for thinker_output in thinker_outputs:
243+
for i, thinker_output in enumerate(thinker_outputs):
174244
output = thinker_output.outputs[0]
175245

246+
decode_emb = output.multimodal_output["0"].detach().to(device=device, dtype=torch.float)
247+
decode_hid = output.multimodal_output["24"].detach().to(device=device, dtype=torch.float)
248+
249+
# Merge prefill prompt embeddings when running in PD mode.
250+
if prefill_stage is not None:
251+
try:
252+
prefill_eos = prefill_stage.engine_outputs
253+
prefill_eo = prefill_eos[min(i, len(prefill_eos) - 1)]
254+
prefill_mm = prefill_eo.outputs[0].multimodal_output
255+
decode_emb, decode_hid = _merge_pd_embeddings(
256+
decode_emb, decode_hid, prefill_mm, device,
257+
)
258+
except Exception as exc:
259+
logger.warning("[PD] Could not merge prefill embeddings: %s", exc)
260+
261+
# Helper: get TTS embed from decode, fall back to prefill if missing.
262+
def _tts(key: str) -> torch.Tensor:
263+
val = output.multimodal_output.get(key)
264+
if val is None and prefill_stage is not None:
265+
try:
266+
val = (
267+
prefill_stage.engine_outputs[0]
268+
.outputs[0]
269+
.multimodal_output.get(key)
270+
)
271+
except Exception:
272+
pass
273+
return val.detach().to(device=device, dtype=torch.float) if val is not None else None
274+
176275
info = {
177-
"thinker_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
178-
"thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float),
276+
"thinker_embeddings": decode_emb,
277+
"thinker_hidden_states": decode_hid,
179278
"thinker_sequences": (
180279
thinker_output.prompt_token_ids + output.token_ids
181280
), # the thinker_sequences is the whole ids
182281
"thinker_input_ids": thinker_output.prompt_token_ids,
183282
# Provide thinker-side TTS token embeddings for talker projection
184-
"tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float),
185-
"tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float),
186-
"tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float),
283+
"tts_bos_embed": _tts("tts_bos_embed"),
284+
"tts_eos_embed": _tts("tts_eos_embed"),
285+
"tts_pad_embed": _tts("tts_pad_embed"),
187286
}
188287

189288
prompt_len = _compute_talker_prompt_ids_length(info, device=device)

0 commit comments

Comments
 (0)