
Commit 68e0f9e

ahengljh and claude committed
[Feature] Add PD (Prefill-Decode) disaggregation for thinker stage
Split the thinker stage into separate prefill and decode instances that communicate via vLLM's native KV transfer (MooncakeConnector). The prefill engine processes prompts and saves the KV cache; the decode engine loads the cache and generates tokens.

Key changes:
- PD detection, validation, and routing in OmniBase and AsyncOmni
- Prefill sampling params: max_tokens=1, neutralize stop conditions
- Patched MooncakeConnector with remote_request_id for cross-engine KV lookup
- Monkey-patch infrastructure with vLLM version compatibility check
- Embedding merge (prefill + decode) in thinker2talker stage processor
- Zero-padding safety with threshold warning in talker model
- Defense-in-depth cleanup of KV params after generation
- Unit tests for PD detection, validation, routing, stop neutralization, failure modes, memory leak prevention, and TP validation
- E2E tests for both text and audio modalities (offline + online)
- PD CI stage config with load_format: dummy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e37a89f commit 68e0f9e
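The commit message notes that prefill-side requests get overridden sampling params (max_tokens=1, neutralized stop conditions) so the prefill engine performs exactly one generation step and can never stop early before the KV cache is saved. The actual implementation is not part of this excerpt; the following is a minimal sketch under assumed names (`SamplingParams` here is a local stand-in for vLLM's class, and `neutralize_for_prefill` is a hypothetical helper):

```python
from dataclasses import dataclass, field


@dataclass
class SamplingParams:
    # Minimal local stand-in for vLLM's SamplingParams (illustrative only).
    max_tokens: int = 128
    min_tokens: int = 0
    stop: list = field(default_factory=list)
    stop_token_ids: list = field(default_factory=list)


def neutralize_for_prefill(params: SamplingParams) -> SamplingParams:
    """Prefill engines only need to populate the KV cache, so force a single
    decode step and disable every early-stop condition (hypothetical helper)."""
    params.max_tokens = 1       # one step is enough to flush the prefill KV
    params.min_tokens = 1       # never terminate before that step
    params.stop = []            # no string stop conditions
    params.stop_token_ids = []  # no token-id stop conditions
    return params


p = neutralize_for_prefill(SamplingParams(max_tokens=512, stop=["</s>"]))
print(p.max_tokens, p.stop)  # 1 []
```

The decode engine then receives the original, un-neutralized params and generates the full response from the transferred cache.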

File tree

36 files changed: +4868 −829 lines


docs/configuration/stage_configs.md

Lines changed: 0 additions & 8 deletions

```diff
@@ -135,14 +135,6 @@ Each stage in the `stage_args` list contains the following configuration options
 
 A unique identifier for each stage in the multi-stage pipeline. Stages are numbered sequentially starting from 0, and this ID is used to reference stages in inter-stage dependencies (e.g., `engine_input_source`).
 
-### `prompt_expand_func` (Optional)
-
-A custom Python function hook for the LLM stage (Stage 0) that expands a single incoming prompt object into multiple prompts. This is primarily used for multi-modal Classifier-Free Guidance (CFG), where it generates the necessary companion requests (like a negative text prompt) and tags them with internal roles (e.g., `cfg_text`). This ensures the upstream LLM generates the needed contextual hidden states for both the conditional and unconditional generations simultaneously.
-
-### `cfg_kv_collect_func` (Optional)
-
-A custom Python function hook for downstream diffusion stages (Stage 1+) to collect, map, and process the KV caches transferred from the companion requests fired by `prompt_expand_func`. It aggregates the hidden condition states cleanly (e.g., binding them as `cfg_text_past_key_values` and `cfg_text_kv_metadata`), allowing the diffusion runtime to perform CFG smoothly without redundantly evaluating text paths on the DiT workers.
-
 ### `runtime`
 
 Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed.
```
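For context on the `prompt_expand_func` hook whose documentation this commit deletes: it maps one incoming prompt to a list of prompts, with CFG companions tagged by an internal role. A hypothetical sketch (the dict keys and default behavior here are assumptions, not the removed implementation):

```python
def expand_cfg_prompts(prompt: dict) -> list[dict]:
    """Hypothetical sketch of a prompt_expand_func hook: pair an incoming
    prompt with a negative CFG companion tagged with an internal role."""
    companion = {
        "prompt": prompt.get("negative_prompt", ""),  # empty negative prompt by default
        "modalities": prompt.get("modalities", ["image"]),
        "role": "cfg_text",  # internal tag consumed by the downstream collector
    }
    return [prompt, companion]


expanded = expand_cfg_prompts({"prompt": "a cat", "modalities": ["image"]})
print([p.get("role", "primary") for p in expanded])  # ['primary', 'cfg_text']
```

The matching `cfg_kv_collect_func` would then gather the companion's KV cache on the diffusion side under keys like `cfg_text_past_key_values`.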

docs/design/architecture_overview.md

Lines changed: 0 additions & 7 deletions

```diff
@@ -92,13 +92,6 @@ The framework achieves high performance through several optimization techniques:
 * **Quantization:** Supports various quantization implementations including FP8 and AWQ.
 * **FusedOps:** Allows for custom and third-party integration.
 
-### Classifier-Free Guidance (CFG) Companion Flow
-
-vLLM-Omni natively models Classifier-Free Guidance (CFG) across disaggregated multi-stage setups via a "companion request" paradigm, eliminating redundant textual/multimodal context computation boundaries:
-1. **Prompt Expansion:** In the initial autoregressive (AR) stage, a customized `prompt_expand_func` hook intercepts incoming generation prompts and pairs them directly with negative companion prompts (e.g., a default negative prompt) on the fly, tagging the secondary prompt with a specific internal role (`cfg_text`).
-2. **Synchronized KV Cache Transfer:** The AR stage evaluates both the primary and companion sequence batches concurrently. The `OmniConnector` captures these specific structural dependencies and reliably passes the positive and negative outcome KV caches seamlessly across stage boundaries via shared memory or network protocols.
-3. **KV Cache Collection & Injection:** Upon reaching the downstream Diffusion (DiT) Engine, an assigned `cfg_kv_collect_func` automatically intercepts the mapped companion caches (`cfg_text_past_key_values`). These auxiliary dependencies are natively gathered and seamlessly bound to the primary generation sequence variables, enabling the DiT Engine to cleanly implement cross-attention CFG guidance over accurate conditioning and unconditioning structures in parallel.
-
 ### Flexibility and Usability
 
 vLLM-Omni is designed to be flexible and straightforward for users:
```

examples/offline_inference/bagel/end2end.py

Lines changed: 2 additions & 7 deletions

```diff
@@ -49,7 +49,7 @@ def parse_args():
     parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
     parser.add_argument("--cfg-img-scale", type=float, default=1.5, help="Image CFG scale (default: 1.5)")
     parser.add_argument(
-        "--negative-prompt", type=str, default=None, help="Negative prompt for CFG (default: empty prompt)"
+        "--negative-prompt", type=str, default=None, help="Negative prompt (not yet supported, reserved for future)"
     )
 
     args = parser.parse_args()
@@ -162,8 +162,6 @@ def main():
         # text2img
         final_prompt_text = f"<|im_start|>{p}<|im_end|>"
         prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
-        if args.negative_prompt is not None:
-            prompt_dict["negative_prompt"] = args.negative_prompt
        formatted_prompts.append(prompt_dict)
 
     params_list = omni.default_sampling_params_list
@@ -172,13 +170,10 @@ def main():
     if len(params_list) > 1:
         diffusion_params = params_list[1]
         diffusion_params.num_inference_steps = args.steps  # type: ignore
-        extra = {
+        diffusion_params.extra_args = {  # type: ignore
             "cfg_text_scale": args.cfg_text_scale,
             "cfg_img_scale": args.cfg_img_scale,
         }
-        if args.negative_prompt is not None:
-            extra["negative_prompt"] = args.negative_prompt
-        diffusion_params.extra_args = extra  # type: ignore
 
     omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
```

tests/diffusion/test_diffusion_model_runner.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -56,10 +56,7 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar
         enable_cache_dit_summary=enable_cache_dit_summary,
         parallel_config=SimpleNamespace(use_hsdp=False),
     )
-    runner.kv_transfer_manager = SimpleNamespace(
-        receive_kv_cache=lambda req, target_device=None: None,
-        receive_multi_kv_cache=lambda req, cfg_kv_collect_func=None, target_device=None: None,
-    )
+    runner.kv_transfer_manager = SimpleNamespace(receive_kv_cache=lambda req, target_device: None)
     return runner
 
 
```
tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml

Lines changed: 0 additions & 2 deletions

```diff
@@ -4,7 +4,6 @@
 stage_args:
   - stage_id: 0
     stage_type: llm
-    prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
     runtime:
       devices: "0"
       max_batch_size: 1
@@ -40,7 +39,6 @@ stage_args:
         to_stage_1: mooncake_connector
   - stage_id: 1
     stage_type: diffusion
-    cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
     runtime:
       devices: "0"
       max_batch_size: 1
```

tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml

Lines changed: 0 additions & 2 deletions

```diff
@@ -4,7 +4,6 @@
 stage_args:
   - stage_id: 0
     stage_type: llm
-    prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
     runtime:
       devices: "0"
       max_batch_size: 1
@@ -39,7 +38,6 @@ stage_args:
 
   - stage_id: 1
     stage_type: diffusion
-    cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
     runtime:
       devices: "0"
       max_batch_size: 1
```

tests/e2e/offline_inference/test_bagel_text2img.py

Lines changed: 10 additions & 14 deletions

```diff
@@ -37,16 +37,16 @@
 # "Generated with seed=52, num_inference_steps=15,
 # prompt='A futuristic city skyline at twilight, cyberpunk style'"
 REFERENCE_PIXELS = [
-    {"position": (100, 100), "rgb": (49, 96, 134)},
-    {"position": (400, 50), "rgb": (63, 127, 167)},
-    {"position": (700, 100), "rgb": (70, 101, 141)},
-    {"position": (150, 400), "rgb": (115, 90, 150)},
-    {"position": (512, 512), "rgb": (98, 86, 119)},
-    {"position": (700, 400), "rgb": (29, 42, 91)},
-    {"position": (100, 700), "rgb": (47, 50, 88)},
-    {"position": (400, 700), "rgb": (36, 52, 91)},
-    {"position": (700, 700), "rgb": (45, 58, 99)},
-    {"position": (256, 256), "rgb": (62, 94, 135)},
+    {"position": (100, 100), "rgb": (68, 107, 134)},
+    {"position": (400, 50), "rgb": (95, 139, 166)},
+    {"position": (700, 100), "rgb": (99, 122, 151)},
+    {"position": (150, 400), "rgb": (111, 125, 153)},
+    {"position": (512, 512), "rgb": (97, 107, 131)},
+    {"position": (700, 400), "rgb": (48, 64, 98)},
+    {"position": (100, 700), "rgb": (79, 63, 84)},
+    {"position": (400, 700), "rgb": (40, 58, 79)},
+    {"position": (700, 700), "rgb": (60, 75, 103)},
+    {"position": (256, 256), "rgb": (97, 128, 156)},
 ]
 
 # Maximum allowed difference per color channel
@@ -80,10 +80,6 @@ def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_st
     params_list[0].max_tokens = max_tokens  # type: ignore
     if len(params_list) > 1:
         params_list[1].num_inference_steps = num_inference_steps  # type: ignore
-        params_list[1].extra_args = {  # type: ignore
-            "cfg_text_scale": 4.0,
-            "cfg_img_scale": 1.5,
-        }
     return params_list
 
 
```
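The `REFERENCE_PIXELS` table above is compared against the generated image with a per-channel tolerance ("Maximum allowed difference per color channel"); the constant's value is truncated out of this hunk. A sketch of that comparison, with the tolerance value assumed:

```python
# Assumed tolerance value; the actual constant is truncated out of this hunk.
MAX_CHANNEL_DIFF = 32


def pixels_match(actual_rgb, reference_rgb, tolerance=MAX_CHANNEL_DIFF):
    """Per-channel absolute-difference check, sketching how reference pixels
    would be validated against a freshly generated image."""
    return all(abs(a - r) <= tolerance for a, r in zip(actual_rgb, reference_rgb))


print(pixels_match((68, 107, 134), (70, 105, 130)))  # True
print(pixels_match((68, 107, 134), (0, 0, 0)))       # False
```

Loose per-pixel tolerances like this keep the E2E test stable across minor kernel or seed-handling changes while still catching gross regressions, which is why the reference values could simply be re-baselined in this commit.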

Lines changed: 66 additions & 0 deletions (new file)

```python
"""
E2E offline tests for Qwen3-Omni-MoE with PD (Prefill-Decode) disaggregation.

Tests both text-only and audio output modalities through the 4-stage
PD pipeline: Prefill -> Decode -> Talker -> Code2Wav.
"""

import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"

from pathlib import Path

import pytest

from tests.conftest import (
    generate_synthetic_video,
)
from tests.utils import hardware_test

models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]

# PD disaggregation CI stage config (requires 3x GPUs)
stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_pd_ci.yaml")]

# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]


def get_question(prompt_type="video"):
    prompts = {
        "video": "Describe the video briefly.",
        "text": "What is the capital of China? Answer in 20 words.",
    }
    return prompts.get(prompt_type, prompts["video"])


@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=3)
@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
def test_pd_text_only(omni_runner, omni_runner_handler) -> None:
    """Test PD disaggregation with text-only output (no talker/code2wav)."""
    request_config = {
        "prompts": get_question("text"),
        "modalities": ["text"],
    }
    omni_runner_handler.send_request(request_config)


@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=3)
@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
def test_pd_video_to_audio(omni_runner, omni_runner_handler) -> None:
    """Test PD disaggregation with video input and audio output
    through the full 4-stage pipeline."""
    video = generate_synthetic_video(224, 224, 300)["np_array"]

    request_config = {
        "prompts": get_question("video"),
        "videos": video,
        "modalities": ["audio"],
    }
    omni_runner_handler.send_request(request_config)
```
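These tests run against a stage config that splits the thinker into prefill and decode stages, which the engine must detect and validate (the commit message's "PD detection, validation, and routing in OmniBase and AsyncOmni"). That code is not in this excerpt; the following is a hypothetical sketch, where the helper name and the `llm_prefill`/`llm_decode` stage-type values are assumptions:

```python
def detect_pd_stages(stage_args):
    """Hypothetical sketch of PD detection: return the (prefill, decode)
    stage-id pair if the config defines exactly one of each, else None."""
    prefill = [s["stage_id"] for s in stage_args if s.get("stage_type") == "llm_prefill"]
    decode = [s["stage_id"] for s in stage_args if s.get("stage_type") == "llm_decode"]
    if len(prefill) == 1 and len(decode) == 1:
        return prefill[0], decode[0]
    return None  # not a PD deployment; route as a normal pipeline


stages = [
    {"stage_id": 0, "stage_type": "llm_prefill"},
    {"stage_id": 1, "stage_type": "llm_decode"},
    {"stage_id": 2, "stage_type": "talker"},
]
print(detect_pd_stages(stages))  # (0, 1)
```

When a pair is detected, requests would first be routed to the prefill stage with neutralized sampling params, then replayed on the decode stage against the transferred KV cache.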
Lines changed: 122 additions & 0 deletions (new file)

```python
"""
E2E online serving tests for Qwen3-Omni-MoE with PD (Prefill-Decode) disaggregation.

Tests both text-only and audio output modalities via the OpenAI-compatible API
through the 4-stage PD pipeline: Prefill -> Decode -> Talker -> Code2Wav.
"""

import os
from pathlib import Path

import pytest

from tests.conftest import (
    dummy_messages_from_mix_data,
    generate_synthetic_audio,
    generate_synthetic_image,
    generate_synthetic_video,
)
from tests.utils import hardware_test

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"

models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]

# PD disaggregation CI stage config (requires 3x GPUs)
stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_pd_ci.yaml")]

# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]


def get_system_prompt():
    return {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are Qwen, a virtual human developed by the Qwen Team, "
                    "Alibaba Group, capable of perceiving auditory and visual inputs, "
                    "as well as generating text and speech."
                ),
            }
        ],
    }


def get_prompt(prompt_type="text_only"):
    prompts = {
        "text_only": "What is the capital of China? Answer in 20 words.",
        "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
    }
    return prompts.get(prompt_type, prompts["text_only"])


@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=3)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_pd_text_to_text(omni_server, openai_client) -> None:
    """
    Test PD disaggregation with text-only output via OpenAI API.
    Deploy Setting: PD separation yaml
    Input Modal: text
    Output Modal: text
    Input Setting: stream=False
    Datasets: single request
    """
    messages = dummy_messages_from_mix_data(
        system_prompt=get_system_prompt(),
        content_text=get_prompt("text_only"),
    )

    request_config = {
        "model": omni_server.model,
        "messages": messages,
        "stream": False,
        "modalities": ["text"],
        "key_words": {"text": ["beijing"]},
    }

    openai_client.send_request(request_config)


@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=3)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_pd_mix_to_text_audio(omni_server, openai_client) -> None:
    """
    Test PD disaggregation with multi-modal input and text+audio output via OpenAI API.
    Deploy Setting: PD separation yaml
    Input Modal: text + audio + video + image
    Output Modal: text + audio
    Input Setting: stream=True
    Datasets: single request
    """
    video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
    image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
    audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
    messages = dummy_messages_from_mix_data(
        system_prompt=get_system_prompt(),
        video_data_url=video_data_url,
        image_data_url=image_data_url,
        audio_data_url=audio_data_url,
        content_text=get_prompt("mix"),
    )

    request_config = {
        "model": omni_server.model,
        "messages": messages,
        "stream": True,
        "key_words": {
            "audio": ["water", "chirping", "crackling", "rain"],
            "image": ["square", "quadrate"],
        },
    }

    openai_client.send_request(request_config)
```
