Skip to content

Commit 606e7cf

Browse files
ahengljh and claude committed
[Bugfix] Fix PD embedding merge: dynamic overlap instead of blind decode[1:] skip
The previous merge blindly skipped decode[0] assuming it overlaps with prefill[-1]. In practice prefill=12 + decode=8 with expected_total=20 means there is NO overlap (12+8=20), so decode[0] should NOT be skipped. The decode[1:] skip produced 19 embeddings for a 20-token sequence, causing misaligned talker input and garbled output. Now _merge_pd_embeddings computes overlap dynamically: overlap = max(0, prefill_len + decode_len - expected_total) This correctly handles both overlap and no-overlap cases. Also added diagnostic logging for prompt_len, output_len, expected_total, and actual embedding shapes to make future debugging easier. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c4a5c27 commit 606e7cf

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

vllm_omni/model_executor/stage_input_processors/qwen3_omni.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,18 +165,21 @@ def _merge_pd_embeddings(
165165
decode_hid: torch.Tensor,
166166
prefill_mm: dict[str, Any],
167167
device: torch.device,
168+
expected_total: int | None = None,
168169
) -> tuple[torch.Tensor, torch.Tensor]:
169170
"""Merge prefill prompt embeddings with decode generated embeddings.
170171
171172
In PD disaggregation the decode engine only produces embeddings for the
172-
tokens it actually computed (1 remaining-prompt + N generated). The
173-
prefill engine has embeddings for the full prompt. We concatenate them
174-
so the talker sees the complete sequence::
175-
176-
merged = prefill[0 : prompt_len] + decode[1:]
177-
^ ^
178-
prompt positions generated positions
179-
(skip overlap at last prompt pos)
173+
tokens it actually computed. The prefill engine has embeddings for the
174+
full prompt. We concatenate them, dynamically computing any overlap::
175+
176+
overlap = prefill_len + decode_len - expected_total
177+
merged = prefill + decode[overlap:]
178+
179+
When ``expected_total`` (= len(prompt_token_ids) + len(output.token_ids))
180+
is provided we use it to decide how many leading decode embeddings to
181+
skip (they duplicate trailing prefill positions). If not provided we
182+
fall back to no-skip concatenation.
180183
"""
181184
try:
182185
p_emb = prefill_mm["0"].detach().to(device=device, dtype=torch.float)
@@ -187,15 +190,23 @@ def _merge_pd_embeddings(
187190
if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0:
188191
return decode_emb, decode_hid
189192

190-
# decode[0] is the recomputed last-prompt-token (overlap with prefill[-1]).
191-
merged_emb = torch.cat([p_emb, decode_emb[1:]], dim=0)
192-
merged_hid = torch.cat([p_hid, decode_hid[1:]], dim=0)
193+
raw_total = p_emb.shape[0] + decode_emb.shape[0]
194+
if expected_total is not None and raw_total > expected_total:
195+
overlap = raw_total - expected_total
196+
else:
197+
overlap = 0
198+
199+
merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0)
200+
merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0)
193201

194202
logger.info(
195-
"[PD] Merged prefill(%d) + decode(%d) → %d embeddings",
203+
"[PD] Merged prefill(%d) + decode(%d) overlap=%d → %d embeddings "
204+
"(expected=%s)",
196205
p_emb.shape[0],
197206
decode_emb.shape[0],
207+
overlap,
198208
merged_emb.shape[0],
209+
expected_total,
199210
)
200211
return merged_emb, merged_hid
201212

@@ -246,6 +257,19 @@ def thinker2talker(
246257
decode_emb = output.multimodal_output["0"].detach().to(device=device, dtype=torch.float)
247258
decode_hid = output.multimodal_output["24"].detach().to(device=device, dtype=torch.float)
248259

260+
# Expected total = prompt tokens + generated tokens (the full sequence).
261+
expected_total = len(thinker_output.prompt_token_ids) + len(output.token_ids)
262+
263+
logger.info(
264+
"[PD] thinker2talker: prompt_len=%d, output_len=%d, "
265+
"expected_total=%d, decode_emb=%d, decode_hid=%d",
266+
len(thinker_output.prompt_token_ids),
267+
len(output.token_ids),
268+
expected_total,
269+
decode_emb.shape[0],
270+
decode_hid.shape[0],
271+
)
272+
249273
# Merge prefill prompt embeddings when running in PD mode.
250274
if prefill_stage is not None:
251275
try:
@@ -254,6 +278,7 @@ def thinker2talker(
254278
prefill_mm = prefill_eo.outputs[0].multimodal_output
255279
decode_emb, decode_hid = _merge_pd_embeddings(
256280
decode_emb, decode_hid, prefill_mm, device,
281+
expected_total=expected_total,
257282
)
258283
except Exception as exc:
259284
logger.warning("[PD] Could not merge prefill embeddings: %s", exc)

0 commit comments

Comments (0)