# Copyright 2025 The Qwen team.
"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition."""

import logging
from typing import Any

import torch

from vllm_omni.engine import OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt

logger = logging.getLogger(__name__)

1619def _compute_talker_prompt_ids_length (info , device : torch .device | str = "cuda" ) -> int :
1720 im_start_token_id = 151644
@@ -141,6 +144,62 @@ def thinker2talker_async_chunk(
141144 return talker_additional_info
142145
143146
147+ def _get_prefill_stage (stage_list : list [Any ], source_stage_id : int ) -> Any | None :
148+ """Return the preceding prefill stage if PD disaggregation is active."""
149+ if source_stage_id <= 0 :
150+ return None
151+ source_stage = stage_list [source_stage_id ]
152+ if not getattr (source_stage , "is_decode_only" , False ):
153+ return None
154+ prev_stage = stage_list [source_stage_id - 1 ]
155+ if (
156+ getattr (prev_stage , "is_prefill_only" , False )
157+ and prev_stage .engine_outputs is not None
158+ ):
159+ return prev_stage
160+ return None
161+
162+
163+ def _merge_pd_embeddings (
164+ decode_emb : torch .Tensor ,
165+ decode_hid : torch .Tensor ,
166+ prefill_mm : dict [str , Any ],
167+ device : torch .device ,
168+ ) -> tuple [torch .Tensor , torch .Tensor ]:
169+ """Merge prefill prompt embeddings with decode generated embeddings.
170+
171+ In PD disaggregation the decode engine only produces embeddings for the
172+ tokens it actually computed (1 remaining-prompt + N generated). The
173+ prefill engine has embeddings for the full prompt. We concatenate them
174+ so the talker sees the complete sequence::
175+
176+ merged = prefill[0 : prompt_len] + decode[1:]
177+ ^ ^
178+ prompt positions generated positions
179+ (skip overlap at last prompt pos)
180+ """
181+ try :
182+ p_emb = prefill_mm ["0" ].detach ().to (device = device , dtype = torch .float )
183+ p_hid = prefill_mm ["24" ].detach ().to (device = device , dtype = torch .float )
184+ except (KeyError , AttributeError , TypeError ):
185+ return decode_emb , decode_hid
186+
187+ if p_emb .shape [0 ] == 0 or decode_emb .shape [0 ] == 0 :
188+ return decode_emb , decode_hid
189+
190+ # decode[0] is the recomputed last-prompt-token (overlap with prefill[-1]).
191+ merged_emb = torch .cat ([p_emb , decode_emb [1 :]], dim = 0 )
192+ merged_hid = torch .cat ([p_hid , decode_hid [1 :]], dim = 0 )
193+
194+ logger .info (
195+ "[PD] Merged prefill(%d) + decode(%d) → %d embeddings" ,
196+ p_emb .shape [0 ],
197+ decode_emb .shape [0 ],
198+ merged_emb .shape [0 ],
199+ )
200+ return merged_emb , merged_hid
201+
202+
144203def thinker2talker (
145204 stage_list : list [Any ],
146205 engine_input_source : list [int ],
@@ -155,6 +214,12 @@ def thinker2talker(
155214 2. Split hidden states into: prompt embeddings + generated embeddings
156215 3. Package for talker with additional information
157216
217+ In PD disaggregation the decode engine's multimodal_output only covers
218+ the tokens it computed (not the full prompt). When a preceding prefill
219+ stage is detected we merge the prefill's prompt embeddings with the
220+ decode's generated embeddings so the talker receives the complete
221+ sequence.
222+
158223 Args:
159224 stage_list: List of stage objects
160225 engine_input_source: Source stage IDs (typically [0] for thinker)
@@ -169,21 +234,55 @@ def thinker2talker(
169234
170235 device = torch .device (current_platform .device_type )
171236
237+ # PD disaggregation: look for a preceding prefill stage whose
238+ # embeddings we need to merge with the decode output.
239+ source_stage_id = engine_input_source [0 ]
240+ prefill_stage = _get_prefill_stage (stage_list , source_stage_id )
241+
172242 # Process each thinker output
173- for thinker_output in thinker_outputs :
243+ for i , thinker_output in enumerate ( thinker_outputs ) :
174244 output = thinker_output .outputs [0 ]
175245
246+ decode_emb = output .multimodal_output ["0" ].detach ().to (device = device , dtype = torch .float )
247+ decode_hid = output .multimodal_output ["24" ].detach ().to (device = device , dtype = torch .float )
248+
249+ # Merge prefill prompt embeddings when running in PD mode.
250+ if prefill_stage is not None :
251+ try :
252+ prefill_eos = prefill_stage .engine_outputs
253+ prefill_eo = prefill_eos [min (i , len (prefill_eos ) - 1 )]
254+ prefill_mm = prefill_eo .outputs [0 ].multimodal_output
255+ decode_emb , decode_hid = _merge_pd_embeddings (
256+ decode_emb , decode_hid , prefill_mm , device ,
257+ )
258+ except Exception as exc :
259+ logger .warning ("[PD] Could not merge prefill embeddings: %s" , exc )
260+
261+ # Helper: get TTS embed from decode, fall back to prefill if missing.
262+ def _tts (key : str ) -> torch .Tensor :
263+ val = output .multimodal_output .get (key )
264+ if val is None and prefill_stage is not None :
265+ try :
266+ val = (
267+ prefill_stage .engine_outputs [0 ]
268+ .outputs [0 ]
269+ .multimodal_output .get (key )
270+ )
271+ except Exception :
272+ pass
273+ return val .detach ().to (device = device , dtype = torch .float ) if val is not None else None
274+
176275 info = {
177- "thinker_embeddings" : output . multimodal_output [ "0" ]. detach (). to ( device = device , dtype = torch . float ) ,
178- "thinker_hidden_states" : output . multimodal_output [ "24" ]. detach (). to ( device = device , dtype = torch . float ) ,
276+ "thinker_embeddings" : decode_emb ,
277+ "thinker_hidden_states" : decode_hid ,
179278 "thinker_sequences" : (
180279 thinker_output .prompt_token_ids + output .token_ids
181280 ), # the thinker_sequences is the whole ids
182281 "thinker_input_ids" : thinker_output .prompt_token_ids ,
183282 # Provide thinker-side TTS token embeddings for talker projection
184- "tts_bos_embed" : output . multimodal_output [ "tts_bos_embed" ]. detach (). to ( device = device , dtype = torch . float ),
185- "tts_eos_embed" : output . multimodal_output [ "tts_eos_embed" ]. detach (). to ( device = device , dtype = torch . float ),
186- "tts_pad_embed" : output . multimodal_output [ "tts_pad_embed" ]. detach (). to ( device = device , dtype = torch . float ),
283+ "tts_bos_embed" : _tts ( "tts_bos_embed" ),
284+ "tts_eos_embed" : _tts ( "tts_eos_embed" ),
285+ "tts_pad_embed" : _tts ( "tts_pad_embed" ),
187286 }
188287
189288 prompt_len = _compute_talker_prompt_ids_length (info , device = device )
0 commit comments