diff --git a/tests/e2e/offline_inference/test_qwen3_omni_pd.py b/tests/e2e/offline_inference/test_qwen3_omni_pd.py new file mode 100644 index 0000000000..6bd776ee91 --- /dev/null +++ b/tests/e2e/offline_inference/test_qwen3_omni_pd.py @@ -0,0 +1,66 @@ +""" +E2E offline tests for Qwen3-Omni-MoE with PD (Prefill-Decode) disaggregation. + +Tests both text-only and audio output modalities through the 4-stage +PD pipeline: Prefill -> Decode -> Talker -> Code2Wav. +""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +from pathlib import Path + +import pytest + +from tests.conftest import ( + generate_synthetic_video, +) +from tests.utils import hardware_test + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] + +# PD disaggregation CI stage config (requires 3x GPUs) +stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_pd_ci.yaml")] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +def get_question(prompt_type="video"): + prompts = { + "video": "Describe the video briefly.", + "text": "What is the capital of China? 
Answer in 20 words.", + } + return prompts.get(prompt_type, prompts["video"]) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=3) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_pd_text_only(omni_runner, omni_runner_handler) -> None: + """Test PD disaggregation with text-only output (no talker/code2wav).""" + request_config = { + "prompts": get_question("text"), + "modalities": ["text"], + } + omni_runner_handler.send_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=3) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_pd_video_to_audio(omni_runner, omni_runner_handler) -> None: + """Test PD disaggregation with video input and audio output + through the full 4-stage pipeline.""" + video = generate_synthetic_video(224, 224, 300)["np_array"] + + request_config = { + "prompts": get_question("video"), + "videos": video, + "modalities": ["audio"], + } + omni_runner_handler.send_request(request_config) diff --git a/tests/e2e/online_serving/test_qwen3_omni_pd.py b/tests/e2e/online_serving/test_qwen3_omni_pd.py new file mode 100644 index 0000000000..19a7c2c569 --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_omni_pd.py @@ -0,0 +1,122 @@ +""" +E2E online serving tests for Qwen3-Omni-MoE with PD (Prefill-Decode) disaggregation. + +Tests both text-only and audio output modalities via the OpenAI-compatible API +through the 4-stage PD pipeline: Prefill -> Decode -> Talker -> Code2Wav. 
+""" + +import os +from pathlib import Path + +import pytest + +from tests.conftest import ( + dummy_messages_from_mix_data, + generate_synthetic_audio, + generate_synthetic_image, + generate_synthetic_video, +) +from tests.utils import hardware_test + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] + +# PD disaggregation CI stage config (requires 3x GPUs) +stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_pd_ci.yaml")] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def get_prompt(prompt_type="text_only"): + prompts = { + "text_only": "What is the capital of China? Answer in 20 words.", + "mix": "What is recited in the audio? What is in this image? Describe the video briefly.", + } + return prompts.get(prompt_type, prompts["text_only"]) + + +@pytest.mark.advanced_model +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=3) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_pd_text_to_text(omni_server, openai_client) -> None: + """ + Test PD disaggregation with text-only output via OpenAI API. 
+ Deploy Setting: PD separation yaml + Input Modal: text + Output Modal: text + Input Setting: stream=False + Datasets: single request + """ + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + content_text=get_prompt("text_only"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": False, + "modalities": ["text"], + "key_words": {"text": ["beijing"]}, + } + + openai_client.send_request(request_config) + + +@pytest.mark.advanced_model +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=3) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_pd_mix_to_text_audio(omni_server, openai_client) -> None: + """ + Test PD disaggregation with multi-modal input and text+audio output via OpenAI API. + Deploy Setting: PD separation yaml + Input Modal: text + audio + video + image + Output Modal: text + audio + Input Setting: stream=True + Datasets: single request + """ + video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}" + image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}" + audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + video_data_url=video_data_url, + image_data_url=image_data_url, + audio_data_url=audio_data_url, + content_text=get_prompt("mix"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": True, + "key_words": { + "audio": ["water", "chirping", "crackling", "rain"], + "image": ["square", "quadrate"], + }, + } + + openai_client.send_request(request_config) diff --git a/tests/e2e/stage_configs/qwen3_omni_pd_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_pd_ci.yaml new file mode 100644 index 0000000000..7f16984a40 --- /dev/null +++ b/tests/e2e/stage_configs/qwen3_omni_pd_ci.yaml @@ -0,0 +1,184 @@ +# 
Stage config for Qwen3-Omni-MoE with PD (Prefill-Decode) disaggregation +# CI variant: uses load_format: dummy so tests can run without real weights. +# +# Stage 0: Thinker Prefill (prompt processing, KV producer) +# Stage 1: Thinker Decode (token generation, KV consumer) +# Stage 2: Talker (text embeddings -> RVQ codec codes) +# Stage 3: Code2Wav (RVQ codes -> audio waveform) +# +# Requires 3x GPUs: GPU 0 = prefill, GPU 1 = decode, GPU 2 = talker + code2wav +# Both prefill and decode stages MUST use the same tensor_parallel_size. + +async_chunk: false +stage_args: + - stage_id: 0 + stage_type: llm + is_prefill_only: true + runtime: + devices: "0" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + load_format: dummy + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_producer" + kv_rank: 0 + kv_parallel_size: 2 + engine_id: "omni-thinker-prefill" + kv_connector_extra_config: + mooncake_bootstrap_port: 25201 + final_output: false + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm + is_decode_only: true + runtime: + devices: "1" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: 
"mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + load_format: dummy + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_consumer" + kv_rank: 1 + kv_parallel_size: 2 + engine_id: "omni-thinker-decode" + kv_connector_extra_config: + mooncake_bootstrap_port: 25202 + engine_input_source: [0] + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 2 + stage_type: llm + runtime: + devices: "2" + max_batch_size: 5 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + load_format: dummy + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 1000 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 3 + stage_type: llm + runtime: + devices: "2" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 100000 + hf_config_name: thinker_config + 
async_scheduling: false + load_format: dummy + engine_input_source: [2] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2000 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + + edges: + - from: 0 + to: 1 + window_size: -1 + - from: 1 + to: 2 + window_size: -1 + - from: 2 + to: 3 + window_size: -1 diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py new file mode 100644 index 0000000000..a855ef1d13 --- /dev/null +++ b/tests/entrypoints/test_pd_disaggregation.py @@ -0,0 +1,1468 @@ +"""Unit tests for PD (Prefill-Decode) disaggregation in the Omni orchestrator. + +Tests the PD detection, validation, config parsing, sampling param +preparation, and routing logic added by the PD disaggregation feature +(issue #1188). All tests run without GPU by using the same mocking +infrastructure as test_omni_llm.py. +""" + +import uuid +import warnings +from queue import Empty, Queue +from typing import Any +from unittest.mock import MagicMock + +import pytest +from vllm import SamplingParams + +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK + +# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. 
+warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + + +# --------------------------------------------------------------------------- +# Fake helpers (same pattern as test_omni_llm.py) +# --------------------------------------------------------------------------- + + +class _FakeEngineArgs(dict): + """Fake engine args that supports both attribute and dict access.""" + + def __init__(self, args_dict: dict[str, Any]): + super().__init__(args_dict) + if "model_stage" not in self: + self["model_stage"] = None + if "engine_output_type" not in self: + self["engine_output_type"] = None + for key, value in self.items(): + setattr(self, key, value) + + +class _FakeStageConfig: + def __init__(self, config_dict: dict[str, Any]): + engine_args_dict = config_dict.get("engine_args", {}) + self.engine_args = _FakeEngineArgs(engine_args_dict) + self.final_output = config_dict.get("final_output", False) + self.final_output_type = config_dict.get("final_output_type", None) + self.stage_id = config_dict.get("stage_id", 0) + self.is_prefill_only = config_dict.get("is_prefill_only", False) + self.is_decode_only = config_dict.get("is_decode_only", False) + self.engine_input_source = config_dict.get("engine_input_source", []) + self.is_comprehension = config_dict.get("is_comprehension", False) + self._config_dict = config_dict + + +class _FakeQueue: + def __init__(self, maxsize=0): + self._queue = Queue(maxsize=maxsize) + + def put(self, item): + self._queue.put(item) + + def put_nowait(self, item): + self._queue.put_nowait(item) + + def get(self): + return self._queue.get() + + def get_nowait(self): + return self._queue.get_nowait() + + def empty(self): + return self._queue.empty() + + +class _FakeStage: + """Lightweight stage stub with PD disaggregation flag support.""" + + def __init__(self, config, stage_init_timeout: int = 300): + if isinstance(config, dict): + config = _FakeStageConfig(config) + 
self.config = config + self.stage_config = config + self.engine = None + self.engine_outputs = None + self.stage_id = getattr(config, "stage_id", 0) + self.engine_args = config.engine_args + self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "llm" + self.default_sampling_params = SamplingParams(temperature=1.0) + self.final_output = config.final_output if hasattr(config, "final_output") else False + self.final_output_type = getattr(config, "final_output_type", None) + self.is_prefill_only = getattr(config, "is_prefill_only", False) + self.is_decode_only = getattr(config, "is_decode_only", False) + self.engine_input_source = getattr(config, "engine_input_source", []) + self.is_comprehension = getattr(config, "is_comprehension", False) + processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) + self._processed_input = processed_input + self._in_q = None + self._out_q = None + self._proc = None + self._stage_init_timeout = max(0, int(stage_init_timeout)) + + def attach_queues(self, in_q, out_q): + self._in_q = in_q + self._out_q = out_q + + def init_stage_worker( + self, model: str, *, is_async=False, shm_threshold_bytes=65536, ctx=None, batch_timeout=10, **kwargs + ): + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + if self._out_q is not None: + try: + self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) + except Exception: + pass + + def stop_stage_worker(self): + if self._in_q is not None: + try: + self._in_q.put_nowait(SHUTDOWN_TASK) + except Exception: + pass + + def submit(self, payload: dict[str, Any]): + if self._in_q is not None: + self._in_q.put(payload) + + def try_collect(self) -> Any: + if self._out_q is None: + return None + try: + return self._out_q.get_nowait() + except Empty: + return None + + def set_engine_outputs(self, 
outputs): + self.engine_outputs = outputs + + def process_engine_inputs(self, stage_list, prompts): + return self._processed_input + + +# --------------------------------------------------------------------------- +# Shared mock setup helpers +# --------------------------------------------------------------------------- + + +def _setup_engine_mocks(monkeypatch): + fake_engine = MagicMock() + fake_engine.tokenizer = MagicMock() + fake_engine.log_stats = False + fake_engine.vllm_config = MagicMock() + fake_engine.vllm_config.model_config = MagicMock() + fake_engine.vllm_config.model_config.io_processor_plugin = None + fake_engine.get_supported_tasks = MagicMock(return_value=[]) + fake_engine.model_config = MagicMock() + fake_engine.model_config.io_processor_plugin = None + fake_registry = MagicMock() + fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch")) + fake_engine.model_config.registry = fake_registry + fake_engine.vllm_config.model_config.registry = fake_registry + + monkeypatch.setattr( + "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", + lambda **kw: fake_engine, + raising=False, + ) + + class FakeModelClass: + pass + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils.get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils._get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + monkeypatch.setattr( + "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", + lambda model_cls: model_cls, + raising=False, + ) + monkeypatch.setattr( + "vllm.multimodal.cache._enable_processor_cache", + lambda model_config, mm_registry: False, + raising=False, + ) + monkeypatch.setattr( + "vllm.plugins.io_processors.get_io_processor", + lambda vllm_config, io_processor_plugin: None, + raising=False, + ) + + +def _setup_multiprocessing_mocks(monkeypatch): + import 
multiprocessing as mp + + fake_process_class = MagicMock() + fake_process_instance = MagicMock() + fake_process_instance.start = MagicMock() + fake_process_instance.join = MagicMock() + fake_process_instance.is_alive = MagicMock(return_value=False) + fake_process_instance.terminate = MagicMock() + fake_process_class.return_value = fake_process_instance + + fake_ctx = MagicMock() + fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) + fake_ctx.Process = fake_process_class + + monkeypatch.setattr(mp, "get_context", lambda method: fake_ctx, raising=False) + monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + + +def _setup_ipc_mocks(monkeypatch): + def _fake_encode(obj, threshold, obj_key, shm_key): + return {obj_key: obj} + + def _fake_load(result, obj_key, shm_key): + return result.get(obj_key) + + def _fake_set(obj): + return str(obj).encode() + + monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) + + +def _setup_log_mocks(monkeypatch): + class _FakeOrchestratorAggregator: + def __init__(self, num_stages, enable_stats, wall_start_ts, final_stage_id_for_e2e=None): + self.num_stages = num_stages + self.enable_stats = enable_stats + self.stage_first_ts = [None] * num_stages + self.stage_last_ts = [None] * num_stages + self.stage_total_tokens = [0] * num_stages + self.accumulated_gen_time_ms = {} + self.e2e_done = set() + self.e2e_count = 0 + self.e2e_total_ms = 0.0 + + def on_stage_metrics(self, stage_id, req_id, metrics, final_output_type=None): + pass + + def on_finalize_request(self, stage_id, req_id, start_ts): + self.e2e_done.add(req_id) + + def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): + pass + + def accumulate_diffusion_metrics(self, stage_type, req_id, engine_outputs): + pass + + def 
record_audio_generated_frames(self, output, stage_id, req_id): + pass + + def stage_postprocess_timer(self, stage_id, req_id): + from contextlib import contextmanager + + @contextmanager + def _noop(): + yield + + return _noop() + + def build_and_log_summary(self): + return "Fake summary" + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.OrchestratorAggregator", + _FakeOrchestratorAggregator, + raising=False, + ) + + +def _clear_modules(): + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + +@pytest.fixture(autouse=True) +def mock_get_config(monkeypatch): + """Auto-mock get_config and related model loading functions.""" + import sys + + fake_tokenizer = MagicMock() + fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + fake_tokenizer.decode = MagicMock(return_value="test") + + def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): + return fake_tokenizer + + monkeypatch.setattr( + "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + tokenizer_module_path = "vllm.transformers_utils.tokenizer" + if tokenizer_module_path in sys.modules: + setattr(sys.modules[tokenizer_module_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + return len(prompt_token_ids) + return 10 + + monkeypatch.setattr( + "vllm.utils.length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds, raising=False + ) + monkeypatch.setattr( + "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + + processor_module_path = 
"vllm_omni.engine.input_processor" + if processor_module_path in sys.modules: + setattr( + sys.modules[processor_module_path], + "length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + ) + + monkeypatch.setattr( + "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", _mock_init_tokenizer_from_configs, raising=False + ) + async_omni_path = "vllm_omni.entrypoints.async_omni" + if async_omni_path in sys.modules: + setattr(sys.modules[async_omni_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + fake_hf_config = MagicMock() + fake_hf_config.model_type = "qwen2_5_omni" + + monkeypatch.setattr( + "vllm.transformers_utils.config.get_config", lambda model, **kwargs: fake_hf_config, raising=False + ) + monkeypatch.setattr("vllm_omni.entrypoints.utils.get_config", lambda model, **kwargs: fake_hf_config, raising=False) + + def _mock_cached_file(path_or_repo_id, *args, **kwargs): + import os + import tempfile + + fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") + if not os.path.exists(fake_config_file): + with open(fake_config_file, "w") as f: + f.write('{"model_type": "qwen2_5_omni"}') + return fake_config_file + + monkeypatch.setattr("transformers.utils.hub.cached_file", _mock_cached_file, raising=False) + monkeypatch.setattr( + "transformers.utils.hub.cached_files", + lambda path_or_repo_id, filenames, **kwargs: ( + [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None + ), + raising=False, + ) + + +# --------------------------------------------------------------------------- +# Helper to build an Omni instance with PD stage configs +# --------------------------------------------------------------------------- + + +def _make_pd_omni(monkeypatch, stage_configs, *, extra_setup=None): + """Create an Omni instance whose stage_list consists of _FakeStage objects + built from *stage_configs* (list of dicts). 
+ + Mocks the full init chain so no real model download, connector init, + or stage process creation occurs. + """ + _clear_modules() + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + configs = [_FakeStageConfig(c) for c in stage_configs] + + # Mock load_and_resolve_stage_configs (the current entry point used by + # OmniBase._resolve_stage_configs). It returns (config_path, stage_configs). + def _fake_load_and_resolve(model, stage_configs_path=None, kwargs=None, **kw): + return ("fake_config_path", configs) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", + _fake_load_and_resolve, + raising=False, + ) + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + # Mock load_and_resolve_stage_configs on the omni module (it's imported + # from utils at the top of omni.py). + monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_load_and_resolve) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock omni_snapshot_download so it doesn't try to download a real model + monkeypatch.setattr(omni_module, "omni_snapshot_download", lambda model_id: model_id) + + # Build a connectors dict that covers all edges so that + # try_send_via_connector is always called (it sends data to the + # next stage's input queue). 
+ num_stages = len(configs) + fake_connectors = {} + for i in range(num_stages - 1): + fake_connectors[(str(i), str(i + 1))] = MagicMock() + + # Mock initialize_orchestrator_connectors so it doesn't parse real configs + monkeypatch.setattr( + omni_module, + "initialize_orchestrator_connectors", + lambda *args, **kwargs: (None, fake_connectors), + ) + + # Mock try_send_via_connector to directly submit to the stage queue + # (the real version uses OmniConnector IPC; in tests we just put the + # payload into the stage's input queue directly). + def _fake_try_send( + connector, + stage_id, + next_stage_id, + req_id, + next_inputs, + sampling_params, + original_prompt, + next_stage_queue_submit_fn, + metrics, + ): + task = { + "request_id": req_id, + "engine_inputs": next_inputs, + "sampling_params": sampling_params, + } + next_stage_queue_submit_fn(task) + return True + + monkeypatch.setattr(omni_module, "try_send_via_connector", _fake_try_send) + + # Mock _start_stages to create fake queues (the real version uses ZMQ + # sockets, multiprocessing contexts, and spawns stage workers). 
+ def _fake_start_stages(self, model): + for stage in self.stage_list: + in_q = _FakeQueue() + out_q = _FakeQueue() + self._stage_in_queues.append(in_q) + self._stage_out_queues.append(out_q) + stage.attach_queues(in_q, out_q) + + from vllm_omni.entrypoints.omni import OmniBase + + monkeypatch.setattr(OmniBase, "_start_stages", _fake_start_stages) + monkeypatch.setattr(OmniBase, "_wait_for_stages_ready", lambda self, timeout=120: None) + + if extra_setup: + extra_setup(monkeypatch, omni_module) + + from vllm_omni.entrypoints.omni import Omni + + return Omni(model="any", init_timeout=1) + + +# --------------------------------------------------------------------------- +# Stage config templates +# --------------------------------------------------------------------------- + + +def _prefill_stage_cfg(stage_id=0, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": { + "model_stage": "thinker", + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_producer", + "kv_rank": 0, + "kv_parallel_size": 2, + "engine_id": "omni-thinker-prefill", + "kv_connector_extra_config": {"mooncake_bootstrap_port": 25201}, + }, + }, + "is_prefill_only": True, + "final_output": False, + "is_comprehension": True, + } + cfg.update(overrides) + return cfg + + +def _decode_stage_cfg(stage_id=1, engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": { + "model_stage": "thinker", + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_consumer", + "kv_rank": 1, + "kv_parallel_size": 2, + "engine_id": "omni-thinker-decode", + "kv_connector_extra_config": {"mooncake_bootstrap_port": 25202}, + }, + }, + "is_decode_only": True, + "engine_input_source": engine_input_source if engine_input_source is not None else [0], + "final_output": True, + "final_output_type": "text", + "is_comprehension": True, + } + cfg.update(overrides) + return cfg + + +def _talker_stage_cfg(stage_id=2, 
engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": {"model_stage": "talker"}, + "engine_input_source": engine_input_source if engine_input_source is not None else [1], + "final_output": False, + } + cfg.update(overrides) + return cfg + + +def _code2wav_stage_cfg(stage_id=3, engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": {"model_stage": "code2wav"}, + "engine_input_source": engine_input_source if engine_input_source is not None else [2], + "final_output": True, + "final_output_type": "audio", + } + cfg.update(overrides) + return cfg + + +# =================================================================== +# Tests: PD pair detection +# =================================================================== + + +class TestDetectPDSeparation: + """Tests for Omni._detect_pd_separation().""" + + def test_detects_pd_pair(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + def test_no_pd_pair_without_flags(self, monkeypatch): + """Normal (non-PD) pipeline has no PD pair.""" + omni = _make_pd_omni( + monkeypatch, + [ + { + "stage_id": 0, + "engine_args": {"model_stage": "thinker"}, + "final_output": True, + "final_output_type": "text", + }, + { + "stage_id": 1, + "engine_args": {"model_stage": "talker"}, + "engine_input_source": [0], + "final_output": True, + "final_output_type": "audio", + }, + ], + ) + assert omni._pd_separation_pair is None + + def test_detects_pd_pair_in_4_stage_pipeline(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + _talker_stage_cfg(stage_id=2, engine_input_source=[1]), + _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + def 
test_pd_pair_uses_stage_id_for_input_source(self, monkeypatch): + """engine_input_source references stage_id, not list index.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=10), + _decode_stage_cfg(stage_id=20, engine_input_source=[10]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + +# =================================================================== +# Tests: PD config validation +# =================================================================== + + +class TestValidatePDConfig: + """Tests for Omni._validate_pd_separation_config().""" + + def test_valid_config_passes(self, monkeypatch): + """Valid PD config should not raise.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + # If we got here without error, validation passed + assert omni._pd_separation_pair == (0, 1) + + def test_mismatched_connector_raises(self, monkeypatch): + """Different kv_connector types should raise ValueError.""" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_connector"] = "NixlConnector" + + with pytest.raises(ValueError, match="connector mismatch"): + _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg]) + + def test_wrong_prefill_role_raises(self, monkeypatch): + """Prefill with kv_consumer role should raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_consumer" + + with pytest.raises(ValueError, match="kv_role must be"): + _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])]) + + def test_wrong_decode_role_raises(self, monkeypatch): + """Decode with kv_producer role should raise.""" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_producer" + + with pytest.raises(ValueError, match="kv_role must be"): + 
_make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg]) + + def test_missing_kv_transfer_config_raises(self, monkeypatch): + """Missing kv_transfer_config should raise.""" + prefill_cfg = _prefill_stage_cfg() + del prefill_cfg["engine_args"]["kv_transfer_config"] + + with pytest.raises(ValueError, match="kv_transfer_config"): + _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])]) + + def test_mismatched_buffer_device_raises(self, monkeypatch): + """Mismatched kv_buffer_device should raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cuda" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cpu" + + with pytest.raises(ValueError, match="kv_buffer_device mismatch"): + _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + + +# =================================================================== +# Tests: Connector info extraction +# =================================================================== + + +class TestGetPDConnectorInfo: + """Tests for Omni._get_pd_connector_info().""" + + def test_extracts_engine_id(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + info = omni._pd_connector_info + assert info is not None + assert info["prefill_engine_id"] == "omni-thinker-prefill" + + def test_extracts_bootstrap_addr_for_mooncake(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + info = omni._pd_connector_info + assert "prefill_bootstrap_addr" in info + assert info["prefill_bootstrap_addr"] == "127.0.0.1:25201" + + def test_none_for_non_pd_pipeline(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + {"stage_id": 0, "engine_args": {}, "final_output": True, "final_output_type": "text"}, + 
], + ) + assert omni._pd_connector_info is None + + +# =================================================================== +# Tests: Prefill sampling params preparation +# =================================================================== + + +class TestPreparePrefillSamplingParams: + """Tests for Omni._prepare_prefill_sampling_params().""" + + def test_sets_max_tokens_to_1(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + assert result.max_tokens == 1 + assert result is not sp # should be cloned + + def test_injects_kv_transfer_params(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + kv_params = result.extra_args["kv_transfer_params"] + assert kv_params["do_remote_decode"] is True + assert kv_params["do_remote_prefill"] is False + assert kv_params["transfer_id"] == "xfer-req-1" + + def test_preserves_existing_extra_args(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, extra_args={"custom_key": "value"}) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + assert result.extra_args["custom_key"] == "value" + assert "kv_transfer_params" in result.extra_args + + def test_does_not_mutate_original(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + _ = omni._prepare_prefill_sampling_params("req-1", sp) + + assert sp.max_tokens == 2048 + assert sp.extra_args is None + + +# 
=================================================================== +# Tests: Sampling params auto-duplication for PD split +# =================================================================== + + +class TestSamplingParamsAutoDuplication: + """When user provides N-1 sampling params (for logical stages), the + orchestrator should auto-duplicate the thinker params for the decode stage. + """ + + def test_auto_duplicates_for_4_stage_pipeline(self, monkeypatch): + """User provides 3 params for 4 physical stages -> auto-insert decode params.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000001") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + _talker_stage_cfg(stage_id=2, engine_input_source=[1]), + _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]), + ], + extra_setup=_extra_setup, + ) + + assert omni._pd_separation_pair == (0, 1) + assert len(omni.stage_list) == 4 + + # Simulate outputs for all stages + expected_rid = f"0_{test_uuid}" + for i in range(4): + omni.stage_list[i]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1, 2])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Provide 3 params (one less than 4 stages) - should auto-duplicate + sp_thinker = SamplingParams(temperature=0.4, max_tokens=2048) + sp_talker = SamplingParams(temperature=0.9, max_tokens=4096) + sp_code2wav = SamplingParams(temperature=0.0, max_tokens=65536) + + # This should NOT raise ValueError about param count mismatch + outputs = omni.generate( + prompts=["hello"], + sampling_params_list=[sp_thinker, sp_talker, sp_code2wav], + ) + assert isinstance(outputs, list) + + +# =================================================================== 
+# Tests: KV transfer params normalization +# =================================================================== + + +class TestNormalizeKVTransferParams: + def test_dict_passthrough(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + d = {"transfer_id": "test", "do_remote_decode": True} + assert omni._normalize_kv_transfer_params(d) is d + + def test_none_returns_none(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._normalize_kv_transfer_params(None) is None + + def test_dataclass_to_dict(self, monkeypatch): + from dataclasses import dataclass + + @dataclass + class FakeKVParams: + transfer_id: str = "test" + do_remote_decode: bool = True + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + result = omni._normalize_kv_transfer_params(FakeKVParams()) + assert isinstance(result, dict) + assert result["transfer_id"] == "test" + + +# =================================================================== +# Tests: _kv_cfg_to_dict +# =================================================================== + + +class TestKvCfgToDict: + def test_dict_passthrough(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + d = {"kv_connector": "MooncakeConnector"} + assert omni._kv_cfg_to_dict(d) is d + + def test_none_returns_empty(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._kv_cfg_to_dict(None) == {} + + def test_dataclass_converted(self, monkeypatch): + from dataclasses import dataclass + + @dataclass + class FakeCfg: + kv_connector: str = "TestConnector" + kv_role: str = "kv_producer" + + omni = _make_pd_omni( + 
monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + result = omni._kv_cfg_to_dict(FakeCfg()) + assert result["kv_connector"] == "TestConnector" + assert result["kv_role"] == "kv_producer" + + +# =================================================================== +# Tests: PD routing in scheduling loop +# =================================================================== + + +class TestPDRouting: + """Test that the scheduling loop correctly routes requests from + prefill to decode stage with proper kv_transfer_params. + """ + + def test_prefill_stage_receives_max_tokens_1(self, monkeypatch): + """Stage 0 (prefill) should receive max_tokens=1.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000002") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + + # Put stage outputs in both queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=["hello"], sampling_params_list=sp_list) + + # Check what was submitted to stage 0's input queue + # (skip the stage_ready message first) + task = omni.stage_list[0]._in_q.get_nowait() + assert task["sampling_params"].max_tokens == 1 + kv_params = 
task["sampling_params"].extra_args["kv_transfer_params"] + assert kv_params["do_remote_decode"] is True + assert kv_params["do_remote_prefill"] is False + assert kv_params["transfer_id"] == f"xfer-{expected_rid}" + + def test_decode_stage_receives_original_prompt(self, monkeypatch): + """Decode stage should get the original prompt (not processed outputs).""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000003") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + original_prompt = "test prompt for PD" + + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=[original_prompt], sampling_params_list=sp_list) + + # Check what was forwarded to stage 1 (decode) + # The connector sends tasks to stage 1's input queue + task = omni.stage_list[1]._in_q.get_nowait() + # The engine_inputs should contain the original prompt + engine_inputs = task.get("engine_inputs") + # For PD routing, the original prompt is wrapped in a list + if isinstance(engine_inputs, list): + assert original_prompt in engine_inputs + else: + assert engine_inputs == original_prompt + + def test_decode_kv_params_have_correct_flags(self, monkeypatch): + """Decode stage kv_transfer_params should 
have correct role flags.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000004") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [MagicMock(request_id=expected_rid, outputs=[MagicMock(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=["hello"], sampling_params_list=sp_list) + + # Check decode task's kv_transfer_params + task = omni.stage_list[1]._in_q.get_nowait() + kv_params = task["sampling_params"].extra_args["kv_transfer_params"] + assert kv_params["do_remote_prefill"] is True + assert kv_params["do_remote_decode"] is False + assert kv_params["transfer_id"] == f"xfer-{expected_rid}" + assert kv_params["remote_engine_id"] == "omni-thinker-prefill" + assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201" + + +# =================================================================== +# Tests: KV params cleanup +# =================================================================== + + +class TestKVParamsCleanup: + def test_drop_cleans_up(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + omni._pd_kv_params_by_req["req-1"] = {"transfer_id": "xfer-1"} + omni._drop_pd_kv_params("req-1") + assert "req-1" 
not in omni._pd_kv_params_by_req + + def test_drop_nonexistent_is_noop(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + omni._drop_pd_kv_params("nonexistent") # should not raise + + def test_pop_returns_stored_params(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + stored = {"transfer_id": "xfer-1", "extra_field": "value"} + omni._pd_kv_params_by_req["req-1"] = stored + + result = omni._pop_pd_kv_params("req-1") + assert result == stored + assert "req-1" not in omni._pd_kv_params_by_req + + def test_pop_uses_fallback_when_no_stored(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + fallback = {"transfer_id": "xfer-fallback"} + result = omni._pop_pd_kv_params("req-1", fallback=fallback) + assert result == fallback + + +# =================================================================== +# Tests: Config YAML loads without error +# =================================================================== + + +class TestPDYAMLConfig: + def test_pd_yaml_loads(self): + """The PD separation YAML config should load without errors.""" + import os + + yaml_path = os.path.join( + os.path.dirname(__file__), + "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml", + ) + yaml_path = os.path.abspath(yaml_path) + if not os.path.exists(yaml_path): + pytest.skip("PD separation YAML not found") + + from omegaconf import OmegaConf + + cfg = OmegaConf.load(yaml_path) + stages = cfg.stage_args + assert len(stages) == 4 + + # Prefill stage + assert stages[0].is_prefill_only is True + assert stages[0].final_output is False + assert stages[0].is_comprehension is True + + # Decode stage + assert stages[1].is_decode_only is True + assert stages[1].final_output is True + assert 
stages[1].final_output_type == "text" + assert stages[1].is_comprehension is True + assert 0 in stages[1].engine_input_source + + # KV transfer configs + assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer" + assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer" + assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" + assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" + + +# =================================================================== +# Tests: MooncakeConnector monkey-patch +# =================================================================== + + +class TestMooncakeConnectorPatch: + """Tests for the embedded MooncakeConnector monkey-patch that fixes + the request-ID mismatch in PD disaggregation. + """ + + def test_stage_payload_includes_pd_flags(self, monkeypatch): + """init_stage_worker should include is_prefill_only / is_decode_only + in stage_payload so the worker process can decide whether to apply + the MooncakeConnector patch. 
+ """ + from vllm_omni.entrypoints.omni_stage import OmniStage + + # Build a minimal stage config with PD flags + stage_cfg = _FakeStageConfig(_prefill_stage_cfg(stage_id=0)) + stage = OmniStage.__new__(OmniStage) + # Manually set required attributes (bypass __init__ which needs real config) + stage.stage_config = stage_cfg + stage.stage_id = 0 + stage.engine_args = stage_cfg.engine_args + stage.is_prefill_only = True + stage.is_decode_only = False + stage.stage_type = "llm" + stage.engine_input_source = [] + stage.final_output = False + stage.final_output_type = None + stage._shm_threshold_bytes = 65536 + stage._stage_init_timeout = 300 + stage._in_q = MagicMock() + stage._out_q = MagicMock() + stage._proc = None + + # Capture the stage_payload by monkeypatching the Process constructor + captured_payloads = [] + + class FakeCtx: + class Process: + def __init__(self, target=None, args=None): + self.target = target + # args[1] is stage_payload for _stage_worker + if args and len(args) >= 2: + captured_payloads.append(args[1]) + + def start(self): + pass + + stage.init_stage_worker("test-model", ctx=FakeCtx()) + + assert len(captured_payloads) == 1 + payload = captured_payloads[0] + assert payload["is_prefill_only"] is True + assert payload["is_decode_only"] is False + + def test_patch_creates_subclass(self): + """create_patched_mooncake_connector should return a class that is a + proper subclass of vLLM's MooncakeConnector (when available). 
+ """ + try: + from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import ( + MooncakeConnector as OriginalMC, + ) + except ImportError: + pytest.skip("vLLM MooncakeConnector not available in this env") + + from vllm_omni.distributed.kv_transfer.patched_mooncake_connector import ( + create_patched_mooncake_connector, + ) + + PatchedCls = create_patched_mooncake_connector(engine_id="test-engine") + assert issubclass(PatchedCls, OriginalMC) + + def test_request_finished_returns_remote_request_id(self): + """The patched request_finished should inject remote_request_id + into kv_transfer_params. + """ + try: + from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import ( + MooncakeConnector as OriginalMC, + ) + except ImportError: + pytest.skip("vLLM MooncakeConnector not available in this env") + + from vllm_omni.distributed.kv_transfer.patched_mooncake_connector import ( + create_patched_mooncake_connector, + ) + + PatchedCls = create_patched_mooncake_connector(engine_id="prefill-0") + + # Create a mock instance without calling __init__ (avoids needing + # real vLLM config), then manually set attributes the method needs. 
+ instance = PatchedCls.__new__(PatchedCls) + instance.engine_id = "prefill-0" + instance.remote_to_local_req = {} + + # Mock request object + fake_request = MagicMock() + fake_request.request_id = "chatcmpl-abc-9dd58560" + + # Mock super().request_finished to return a dict + original_rf = OriginalMC.request_finished + monkeypatch_result = {"do_remote_decode": True, "transfer_id": "xfer-1"} + + def _mock_super_rf(self, request, block_ids): + return dict(monkeypatch_result) + + OriginalMC.request_finished = _mock_super_rf + try: + result = PatchedCls.request_finished(instance, fake_request, [1, 2, 3]) + finally: + OriginalMC.request_finished = original_rf + + assert result is not None + assert result["remote_request_id"] == "chatcmpl-abc-9dd58560" + + def test_add_new_req_uses_remote_request_id(self): + """When load_remote_cache=True, the patched add_new_req should + store a PatchedRecvReqMeta with the remote_request_id from + kv_transfer_params. + """ + try: + from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import ( # noqa: F401 + MooncakeConnector as OriginalMC, + ) + except ImportError: + pytest.skip("vLLM MooncakeConnector not available in this env") + + from vllm_omni.distributed.kv_transfer.patched_mooncake_connector import ( + PatchedRecvReqMeta, + create_patched_mooncake_connector, + ) + + PatchedCls = create_patched_mooncake_connector(engine_id="decode-0") + + instance = PatchedCls.__new__(PatchedCls) + instance.engine_id = "decode-0" + instance.remote_to_local_req = {} + instance._reqs_need_recv = {} + + kv_params = { + "do_remote_prefill": True, + "remote_request_id": "chatcmpl-abc-9dd58560", + "transfer_id": "xfer-1", + } + instance.add_new_req( + request_id="chatcmpl-abc-ab52f3ea", + local_block_ids=[10, 20, 30], + kv_transfer_params=kv_params, + ) + + assert "chatcmpl-abc-ab52f3ea" in instance._reqs_need_recv + meta = instance._reqs_need_recv["chatcmpl-abc-ab52f3ea"] + assert isinstance(meta, PatchedRecvReqMeta) + assert 
meta.request_id == "chatcmpl-abc-ab52f3ea" + assert meta.remote_request_id == "chatcmpl-abc-9dd58560" + assert meta.local_block_ids == [10, 20, 30] + + +# =================================================================== +# Tests: Stop neutralization in prefill sampling params +# =================================================================== + + +class TestPrefillStopNeutralization: + """Tests that _prepare_prefill_sampling_params neutralizes stop + conditions to ensure finish_reason='length'. + """ + + def test_clears_stop_strings(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop=["", "STOP"]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop == [] + + def test_clears_stop_token_ids(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop_token_ids == [] + + def test_clears_include_stop_str_in_output(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.include_stop_str_in_output is False + + def test_original_sp_unchanged(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop=[""], stop_token_ids=[151643]) + _ = omni._prepare_prefill_sampling_params("req-1", sp) + assert sp.stop == [""] + assert sp.stop_token_ids == [151643] + + +# 
=================================================================== +# Tests: Failure mode & memory leak prevention +# =================================================================== +# NOTE: Full generate()-level failure mode tests are removed for now. +# The _run_generation error handler (line 1344-1350 in omni.py) calls +# _drop_pd_kv_params but does not increment completed_requests, causing +# the while-loop to hang. These tests need to be revisited once the +# production error-handling path is fixed to properly terminate on +# stage errors. + + +# =================================================================== +# Tests: TP size validation +# =================================================================== + + +class TestTPSizeValidation: + """Tests that _validate_pd_separation_config checks tensor_parallel_size.""" + + def test_matching_tp_passes(self, monkeypatch): + """Same TP size should not raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 2 + omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + assert omni._pd_separation_pair == (0, 1) + + def test_mismatched_tp_raises(self, monkeypatch): + """Different TP sizes should raise ValueError.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 4 + with pytest.raises(ValueError, match="tensor_parallel_size"): + _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + + def test_default_tp_no_error(self, monkeypatch): + """Stages without explicit TP (defaults to 1) should pass.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._pd_separation_pair == (0, 1) diff --git 
a/tests/model_executor/__init__.py b/tests/model_executor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/model_executor/stage_input_processors/__init__.py b/tests/model_executor/stage_input_processors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_stage_processors.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_stage_processors.py new file mode 100644 index 0000000000..841cdc4f0c --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_stage_processors.py @@ -0,0 +1,1592 @@ +"""Unit tests for Qwen3 Omni stage input processors. + +Tests the thinker->talker and talker->code2wav transition functions, +with special focus on PD (Prefill-Decode) disaggregation embedding merge +logic that is critical for correct audio generation. + +All tests run on CPU without requiring a GPU or model weights. +""" + +import warnings +from collections import defaultdict +from typing import Any +from unittest.mock import MagicMock + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# Suppress noisy DeprecationWarnings from optional Swig bindings. 
+warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + +# --------------------------------------------------------------------------- +# Constants mirroring the production code +# --------------------------------------------------------------------------- +_EMBED_LAYER_KEY = "0" +_HIDDEN_LAYER_KEY = "24" + +# Token IDs used in _compute_talker_prompt_ids_length +_IM_START_TOKEN_ID = 151644 +_SYSTEM_TOKEN_ID = 8948 +_USER_TOKEN_ID = 872 +_ASSISTANT_TOKEN_ID = 77091 + + +# --------------------------------------------------------------------------- +# Fixture: force CPU device for thinker2talker / talker2code2wav +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _force_cpu_platform(monkeypatch): + """Ensure current_platform.device_type == 'cpu' for all tests. + + thinker2talker() uses ``torch.device(current_platform.device_type)`` + internally, which would fail on machines without CUDA. 
+ """ + try: + import vllm.platforms as plat_mod + + monkeypatch.setattr(plat_mod.current_platform, "device_type", "cpu") + except (ImportError, AttributeError): + pass # vllm not installed with platform support — tests still work + + +# --------------------------------------------------------------------------- +# Fake stage / output helpers +# --------------------------------------------------------------------------- + + +class _FakeCompletionOutput: + """Minimal stand-in for vLLM CompletionOutput.""" + + def __init__( + self, + token_ids: list[int], + multimodal_output: dict[str, Any] | None = None, + ): + self.token_ids = token_ids + self.multimodal_output = multimodal_output or {} + + +class _FakeRequestOutput: + """Minimal stand-in for vLLM RequestOutput.""" + + def __init__( + self, + request_id: str, + prompt_token_ids: list[int], + outputs: list[_FakeCompletionOutput], + ): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + +class _FakeStage: + """Lightweight stage stub for testing stage input processors.""" + + def __init__( + self, + stage_id: int = 0, + is_prefill_only: bool = False, + is_decode_only: bool = False, + engine_outputs: list | None = None, + ): + self.stage_id = stage_id + self.is_prefill_only = is_prefill_only + self.is_decode_only = is_decode_only + self.engine_outputs = engine_outputs + + +def _make_multimodal_output( + embed: torch.Tensor, + hidden: torch.Tensor, + *, + tts_bos: torch.Tensor | None = None, + tts_eos: torch.Tensor | None = None, + tts_pad: torch.Tensor | None = None, + codec_codes: torch.Tensor | None = None, +) -> dict[str, Any]: + """Build a multimodal_output dict like the thinker/talker produces.""" + mm: dict[str, Any] = { + _EMBED_LAYER_KEY: embed, + _HIDDEN_LAYER_KEY: hidden, + } + if tts_bos is not None: + mm["tts_bos_embed"] = tts_bos + if tts_eos is not None: + mm["tts_eos_embed"] = tts_eos + if tts_pad is not None: + mm["tts_pad_embed"] = tts_pad + if 
codec_codes is not None: + mm["code_predictor_codes"] = codec_codes + return mm + + +def _rand(rows: int, dim: int = 16) -> torch.Tensor: + return torch.randn(rows, dim) + + +# --------------------------------------------------------------------------- +# Helpers to build realistic token sequences +# --------------------------------------------------------------------------- + + +def _build_chat_token_ids( + user_len: int = 10, + assistant_generated_len: int = 5, +) -> tuple[list[int], list[int]]: + """Build (prompt_token_ids, output_token_ids) with realistic structure. + + Structure: <|im_start|>user <|im_start|>assistant + """ + # User turn: <|im_start|> user + user_turn = [_IM_START_TOKEN_ID, _USER_TOKEN_ID] + list(range(100, 100 + user_len)) + # Assistant turn prefix: <|im_start|> assistant + assistant_prefix = [_IM_START_TOKEN_ID, _ASSISTANT_TOKEN_ID] + prompt_token_ids = user_turn + assistant_prefix + + # Generated tokens + output_token_ids = list(range(200, 200 + assistant_generated_len)) + + return prompt_token_ids, output_token_ids + + +# =================================================================== +# Tests: _merge_pd_embeddings +# =================================================================== + + +class TestMergePDEmbeddings: + """Tests for _merge_pd_embeddings() -- the core PD embedding merge.""" + + def _import(self): + from vllm_omni.model_executor.stage_input_processors.qwen3_omni import ( + _merge_pd_embeddings, + ) + + return _merge_pd_embeddings + + def test_basic_merge_no_overlap(self): + """When prefill_len + decode_len == expected_total, no overlap.""" + merge = self._import() + prefill_emb = _rand(10) + prefill_hid = _rand(10) + decode_emb = _rand(5) + decode_hid = _rand(5) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=15, + ) + + assert merged_emb.shape[0] == 15 + 
assert merged_hid.shape[0] == 15 + # First 10 rows should come from prefill + assert torch.allclose(merged_emb[:10], prefill_emb) + # Last 5 rows should come from decode + assert torch.allclose(merged_emb[10:], decode_emb) + + def test_merge_with_overlap(self): + """When prefill_len + decode_len > expected_total, skip overlap.""" + merge = self._import() + prefill_emb = _rand(10) + prefill_hid = _rand(10) + decode_emb = _rand(8) + decode_hid = _rand(8) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + # 10 + 8 = 18, expected = 15, so overlap = 3 + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=15, + ) + + assert merged_emb.shape[0] == 15 + assert merged_hid.shape[0] == 15 + # First 10 come from prefill + assert torch.allclose(merged_emb[:10], prefill_emb) + # Last 5 come from decode[3:] (skipping 3 overlap tokens) + assert torch.allclose(merged_emb[10:], decode_emb[3:]) + + def test_merge_without_expected_total(self): + """Without expected_total, should concatenate fully (no overlap skip).""" + merge = self._import() + prefill_emb = _rand(10) + prefill_hid = _rand(10) + decode_emb = _rand(5) + decode_hid = _rand(5) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=None, + ) + + assert merged_emb.shape[0] == 15 + assert merged_hid.shape[0] == 15 + + def test_merge_preserves_hidden_consistency(self): + """Embedding and hidden state merges should have same length.""" + merge = self._import() + dim_emb, dim_hid = 16, 32 + prefill_emb = torch.randn(10, dim_emb) + prefill_hid = torch.randn(10, dim_hid) + decode_emb = torch.randn(5, dim_emb) + decode_hid = torch.randn(5, dim_hid) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + merged_emb, merged_hid 
= merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=15, + ) + + assert merged_emb.shape == (15, dim_emb) + assert merged_hid.shape == (15, dim_hid) + + def test_empty_prefill_returns_decode_only(self): + """If prefill embeddings are empty, return decode unchanged.""" + merge = self._import() + decode_emb = _rand(5) + decode_hid = _rand(5) + + prefill_mm = { + _EMBED_LAYER_KEY: torch.empty(0, 16), + _HIDDEN_LAYER_KEY: torch.empty(0, 16), + } + + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=5, + ) + + assert torch.equal(merged_emb, decode_emb) + assert torch.equal(merged_hid, decode_hid) + + def test_empty_decode_returns_decode_only(self): + """If decode embeddings are empty, return decode unchanged.""" + merge = self._import() + decode_emb = torch.empty(0, 16) + decode_hid = torch.empty(0, 16) + + prefill_mm = { + _EMBED_LAYER_KEY: _rand(10), + _HIDDEN_LAYER_KEY: _rand(10), + } + + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=10, + ) + + # When decode is empty, function returns decode unchanged + assert torch.equal(merged_emb, decode_emb) + assert torch.equal(merged_hid, decode_hid) + + def test_missing_key_returns_decode_unchanged(self): + """If prefill_mm is missing required keys, return decode as-is.""" + merge = self._import() + decode_emb = _rand(5) + decode_hid = _rand(5) + + # Missing _EMBED_LAYER_KEY + prefill_mm = {_HIDDEN_LAYER_KEY: _rand(10)} + + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=15, + ) + + assert torch.equal(merged_emb, decode_emb) + assert torch.equal(merged_hid, decode_hid) + + def test_overlap_equals_decode_len_gives_prefill_only(self): + """If computed overlap >= decode_len, all decode tokens are skipped.""" + merge = self._import() + prefill_emb = _rand(10) + 
prefill_hid = _rand(10) + decode_emb = _rand(3) + decode_hid = _rand(3) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + # 10 + 3 = 13, expected = 10 -> overlap = 3 -> decode[3:] is empty + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=10, + ) + + # Result should be just the prefill embeddings + assert merged_emb.shape[0] == 10 + assert torch.allclose(merged_emb, prefill_emb) + + def test_expected_total_smaller_than_both_uses_no_overlap(self): + """When raw_total <= expected_total, overlap=0 (simple concat).""" + merge = self._import() + prefill_emb = _rand(5) + prefill_hid = _rand(5) + decode_emb = _rand(3) + decode_hid = _rand(3) + + prefill_mm = { + _EMBED_LAYER_KEY: prefill_emb, + _HIDDEN_LAYER_KEY: prefill_hid, + } + + # 5 + 3 = 8 <= 20 -> overlap = 0 + merged_emb, merged_hid = merge( + decode_emb, + decode_hid, + prefill_mm, + device=torch.device("cpu"), + expected_total=20, + ) + + assert merged_emb.shape[0] == 8 + assert torch.allclose(merged_emb[:5], prefill_emb) + assert torch.allclose(merged_emb[5:], decode_emb) + + +# =================================================================== +# Tests: _get_prefill_stage +# =================================================================== + + +class TestGetPrefillStage: + """Tests for _get_prefill_stage() -- prefill stage detection for PD mode.""" + + def _import(self): + from vllm_omni.model_executor.stage_input_processors.qwen3_omni import ( + _get_prefill_stage, + ) + + return _get_prefill_stage + + def test_returns_prefill_stage_when_pd_active(self): + """Should return the prefill stage when source is decode-only.""" + get_prefill = self._import() + + prefill = _FakeStage( + stage_id=0, + is_prefill_only=True, + engine_outputs=[MagicMock()], + ) + decode = _FakeStage(stage_id=1, is_decode_only=True) + stage_list = [prefill, decode] + + result = get_prefill(stage_list, source_stage_id=1) 
# ===================================================================
# Tests: thinker2talker (non-PD mode)
# ===================================================================


class TestThinker2TalkerNonPD:
    """Tests for thinker2talker() in standard (non-PD) mode."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            thinker2talker,
        )

        return thinker2talker

    def _make_thinker_output(
        self,
        prompt_len: int = 14,
        output_len: int = 5,
        dim: int = 16,
    ) -> _FakeRequestOutput:
        """Create a fake thinker output with correct embeddings."""
        prompt_ids, output_ids = _build_chat_token_ids(
            user_len=prompt_len - 4,  # subtract im_start + user + im_start + assistant
            assistant_generated_len=output_len,
        )
        total_len = len(prompt_ids) + len(output_ids)

        mm_output = _make_multimodal_output(
            embed=_rand(total_len, dim),
            hidden=_rand(total_len, dim),
            tts_bos=_rand(1, dim),
            tts_eos=_rand(1, dim),
            tts_pad=_rand(1, dim),
        )

        return _FakeRequestOutput(
            request_id="req-0",
            prompt_token_ids=prompt_ids,
            outputs=[
                _FakeCompletionOutput(
                    token_ids=output_ids,
                    multimodal_output=mm_output,
                )
            ],
        )

    def test_produces_talker_input(self):
        """Basic thinker2talker should produce a valid OmniTokensPrompt."""
        thinker2talker = self._import()
        thinker_out = self._make_thinker_output()

        thinker = _FakeStage(stage_id=0, engine_outputs=[thinker_out])
        stage_list = [thinker]

        results = thinker2talker(stage_list, engine_input_source=[0])

        assert len(results) == 1
        result = results[0]
        assert "prompt_token_ids" in result
        assert "additional_information" in result

        info = result["additional_information"]
        assert "thinker_embeddings" in info
        assert "thinker_hidden_states" in info
        assert "thinker_sequences" in info
        assert "thinker_input_ids" in info
        assert "tts_bos_embed" in info
        assert "tts_eos_embed" in info
        assert "tts_pad_embed" in info

    def test_prompt_token_ids_has_correct_length(self):
        """prompt_token_ids length should match _compute_talker_prompt_ids_length."""
        thinker2talker = self._import()
        thinker_out = self._make_thinker_output(prompt_len=14, output_len=5)

        thinker = _FakeStage(stage_id=0, engine_outputs=[thinker_out])
        results = thinker2talker([thinker], engine_input_source=[0])

        # prompt_token_ids is [0]*prompt_len, so it should be all zeros.
        prompt_tids = results[0]["prompt_token_ids"]
        assert all(t == 0 for t in prompt_tids)
        assert len(prompt_tids) > 0

    def test_thinker_sequences_is_full_concat(self):
        """thinker_sequences should be prompt + output token ids."""
        thinker2talker = self._import()
        thinker_out = self._make_thinker_output(prompt_len=14, output_len=5)

        thinker = _FakeStage(stage_id=0, engine_outputs=[thinker_out])
        results = thinker2talker([thinker], engine_input_source=[0])
        info = results[0]["additional_information"]

        expected_seq = thinker_out.prompt_token_ids + thinker_out.outputs[0].token_ids
        assert info["thinker_sequences"] == expected_seq

    def test_thinker_input_ids_is_prompt_only(self):
        """thinker_input_ids should be only the prompt token ids."""
        thinker2talker = self._import()
        thinker_out = self._make_thinker_output()

        thinker = _FakeStage(stage_id=0, engine_outputs=[thinker_out])
        results = thinker2talker([thinker], engine_input_source=[0])
        info = results[0]["additional_information"]

        assert info["thinker_input_ids"] == thinker_out.prompt_token_ids

    def test_embeddings_shape_matches_total_sequence(self):
        """Embeddings dim-0 should equal len(prompt) + len(output)."""
        thinker2talker = self._import()
        prompt_len, output_len, dim = 14, 5, 16
        thinker_out = self._make_thinker_output(prompt_len, output_len, dim)

        thinker = _FakeStage(stage_id=0, engine_outputs=[thinker_out])
        results = thinker2talker([thinker], engine_input_source=[0])
        info = results[0]["additional_information"]

        total = len(thinker_out.prompt_token_ids) + len(thinker_out.outputs[0].token_ids)
        assert info["thinker_embeddings"].shape[0] == total
        assert info["thinker_hidden_states"].shape[0] == total

    def test_invalid_stage_raises(self):
        """Empty engine_input_source should raise."""
        thinker2talker = self._import()
        with pytest.raises(ValueError, match="cannot be empty"):
            thinker2talker([], engine_input_source=[])

    def test_no_outputs_raises(self):
        """Stage with no outputs should raise."""
        thinker2talker = self._import()
        thinker = _FakeStage(stage_id=0, engine_outputs=None)

        with pytest.raises(RuntimeError, match="no outputs"):
            thinker2talker([thinker], engine_input_source=[0])

    def test_multiple_outputs(self):
        """Multiple thinker outputs should produce multiple talker inputs."""
        thinker2talker = self._import()
        out1 = self._make_thinker_output(prompt_len=14, output_len=3)
        out2 = self._make_thinker_output(prompt_len=14, output_len=7)

        thinker = _FakeStage(stage_id=0, engine_outputs=[out1, out2])
        results = thinker2talker([thinker], engine_input_source=[0])
        assert len(results) == 2
# ===================================================================
# Tests: thinker2talker (PD mode)
# ===================================================================


class TestThinker2TalkerPDMode:
    """Tests for thinker2talker() when PD disaggregation is active.

    In PD mode:
    - Stage 0 = prefill (is_prefill_only=True), has prompt embeddings
    - Stage 1 = decode (is_decode_only=True), has generated embeddings
    - thinker2talker should merge both to form the full sequence
    """

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            thinker2talker,
        )

        return thinker2talker

    def _make_pd_stages(
        self,
        prompt_len: int = 14,
        output_len: int = 5,
        dim: int = 16,
        overlap: int = 0,
        prefill_has_tts: bool = True,
        decode_has_tts: bool = True,
    ):
        """Build prefill + decode stages for PD testing.

        Returns (stage_list, expected_total_embeddings).
        """
        prompt_ids, output_ids = _build_chat_token_ids(
            user_len=prompt_len - 4,
            assistant_generated_len=output_len,
        )
        total_len = len(prompt_ids) + len(output_ids)

        # Prefill: embeddings for the prompt tokens only.
        prefill_emb_len = len(prompt_ids)
        prefill_mm = _make_multimodal_output(
            embed=_rand(prefill_emb_len, dim),
            hidden=_rand(prefill_emb_len, dim),
            tts_bos=_rand(1, dim) if prefill_has_tts else None,
            tts_eos=_rand(1, dim) if prefill_has_tts else None,
            tts_pad=_rand(1, dim) if prefill_has_tts else None,
        )

        prefill_out = _FakeRequestOutput(
            request_id="req-0",
            prompt_token_ids=prompt_ids,
            outputs=[
                _FakeCompletionOutput(
                    token_ids=[output_ids[0]],  # prefill produces 1 token
                    multimodal_output=prefill_mm,
                )
            ],
        )

        # Decode: embeddings for generated tokens (+ possible overlap).
        decode_emb_len = len(output_ids) + overlap
        decode_mm = _make_multimodal_output(
            embed=_rand(decode_emb_len, dim),
            hidden=_rand(decode_emb_len, dim),
            tts_bos=_rand(1, dim) if decode_has_tts else None,
            tts_eos=_rand(1, dim) if decode_has_tts else None,
            tts_pad=_rand(1, dim) if decode_has_tts else None,
        )

        decode_out = _FakeRequestOutput(
            request_id="req-0",
            prompt_token_ids=prompt_ids,
            outputs=[
                _FakeCompletionOutput(
                    token_ids=output_ids,
                    multimodal_output=decode_mm,
                )
            ],
        )

        prefill_stage = _FakeStage(
            stage_id=0,
            is_prefill_only=True,
            engine_outputs=[prefill_out],
        )
        decode_stage = _FakeStage(
            stage_id=1,
            is_decode_only=True,
            engine_outputs=[decode_out],
        )

        return [prefill_stage, decode_stage], total_len

    def test_pd_merge_basic(self):
        """PD mode should merge prefill + decode embeddings."""
        thinker2talker = self._import()
        stage_list, total_len = self._make_pd_stages(
            prompt_len=14,
            output_len=5,
            overlap=0,
        )

        results = thinker2talker(stage_list, engine_input_source=[1])

        assert len(results) == 1
        info = results[0]["additional_information"]
        emb = info["thinker_embeddings"]
        hid = info["thinker_hidden_states"]
        # Merged length should equal prompt + output.
        assert emb.shape[0] == total_len
        assert hid.shape[0] == total_len

    def test_pd_merge_with_overlap(self):
        """PD mode with overlapping tokens should handle deduplication."""
        thinker2talker = self._import()
        stage_list, total_len = self._make_pd_stages(
            prompt_len=14,
            output_len=5,
            overlap=2,
        )

        results = thinker2talker(stage_list, engine_input_source=[1])
        info = results[0]["additional_information"]
        emb = info["thinker_embeddings"]

        # Should still be total_len despite overlap.
        assert emb.shape[0] == total_len

    def test_pd_tts_fallback_to_prefill(self):
        """When decode lacks TTS embeds, should fall back to prefill's."""
        thinker2talker = self._import()
        stage_list, _ = self._make_pd_stages(
            prompt_len=14,
            output_len=5,
            prefill_has_tts=True,
            decode_has_tts=False,
        )

        results = thinker2talker(stage_list, engine_input_source=[1])
        info = results[0]["additional_information"]

        # TTS embeds should be present (from prefill fallback).
        assert info["tts_bos_embed"] is not None
        assert info["tts_eos_embed"] is not None
        assert info["tts_pad_embed"] is not None

    def test_pd_tts_from_decode_when_available(self):
        """When decode has TTS embeds, should use them (not prefill's)."""
        thinker2talker = self._import()
        stage_list, _ = self._make_pd_stages(
            prompt_len=14,
            output_len=5,
            prefill_has_tts=True,
            decode_has_tts=True,
        )

        # Grab the decode stage's TTS embed for comparison.
        decode_tts_bos = stage_list[1].engine_outputs[0].outputs[0].multimodal_output["tts_bos_embed"]

        results = thinker2talker(stage_list, engine_input_source=[1])
        info = results[0]["additional_information"]

        # _tts() does: val.detach().to(device=cpu, dtype=torch.float)
        # decode_tts_bos is already float32 on CPU from _rand(), so values match.
        assert torch.equal(info["tts_bos_embed"], decode_tts_bos.detach().float())

    def test_pd_no_tts_anywhere_gives_none(self):
        """When neither decode nor prefill has TTS embeds, result is None."""
        thinker2talker = self._import()
        stage_list, _ = self._make_pd_stages(
            prompt_len=14,
            output_len=5,
            prefill_has_tts=False,
            decode_has_tts=False,
        )

        results = thinker2talker(stage_list, engine_input_source=[1])
        info = results[0]["additional_information"]

        assert info["tts_bos_embed"] is None
        assert info["tts_eos_embed"] is None
        assert info["tts_pad_embed"] is None

    def test_pd_sequences_are_full(self):
        """In PD mode, thinker_sequences should still be full prompt + output."""
        thinker2talker = self._import()
        stage_list, _ = self._make_pd_stages(prompt_len=14, output_len=5)

        results = thinker2talker(stage_list, engine_input_source=[1])
        info = results[0]["additional_information"]

        decode_out = stage_list[1].engine_outputs[0]
        expected_seq = decode_out.prompt_token_ids + decode_out.outputs[0].token_ids
        assert info["thinker_sequences"] == expected_seq

    def test_pd_prefill_merge_error_is_graceful(self):
        """If prefill embeddings are corrupted, merge fails gracefully
        and decode embeddings are used as-is (logged as warning)."""
        thinker2talker = self._import()
        stage_list, _ = self._make_pd_stages(prompt_len=14, output_len=5)

        # Corrupt prefill multimodal_output to trigger an exception in merge.
        stage_list[0].engine_outputs[0].outputs[0].multimodal_output = "not-a-dict"

        # Should not raise; falls back to decode-only embeddings.
        results = thinker2talker(stage_list, engine_input_source=[1])
        assert len(results) == 1
        info = results[0]["additional_information"]
        assert info["thinker_embeddings"] is not None
# ===================================================================
# Tests: talker2code2wav
# ===================================================================


class TestTalker2Code2Wav:
    """Tests for talker2code2wav() -- the talker -> code2wav transition."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            talker2code2wav,
        )

        return talker2code2wav

    def _make_talker_output(
        self,
        seq_len: int = 20,
        num_quantizers: int = 8,
    ) -> _FakeRequestOutput:
        """Create a fake talker output with codec codes.

        The talker produces token_ids of length seq_len+1 (including a
        start/padding token), and codec codes of shape
        [num_quantizers, seq_len+1]. talker2code2wav keeps the last
        seq_len time steps (i.e. ``codes[:, -seq_len:]``) and flattens
        them time-major into seq_len * num_quantizers codes.
        """
        codec_codes = torch.randint(0, 1024, (num_quantizers, seq_len + 1))
        mm_output = {"code_predictor_codes": codec_codes}

        # token_ids length = seq_len + 1 (code uses len(token_ids) - 1)
        token_ids = list(range(seq_len + 1))

        return _FakeRequestOutput(
            request_id="req-0",
            prompt_token_ids=[0] * 10,  # dummy prompt
            outputs=[
                _FakeCompletionOutput(
                    token_ids=token_ids,
                    multimodal_output=mm_output,
                )
            ],
        )

    def test_produces_code2wav_input(self):
        """Should produce a valid OmniTokensPrompt with flattened codec codes."""
        talker2code2wav = self._import()
        talker_out = self._make_talker_output(seq_len=20, num_quantizers=8)

        talker = _FakeStage(stage_id=1, engine_outputs=[talker_out])
        stage_list = [_FakeStage(stage_id=0), talker]

        results = talker2code2wav(stage_list, engine_input_source=[1])

        assert len(results) == 1
        result = results[0]
        assert "prompt_token_ids" in result
        # Flattened: seq_len * num_quantizers
        assert len(result["prompt_token_ids"]) == 20 * 8

    def test_flattened_code_values_match_source(self):
        """Flattened codes should match the time-sliced, transposed original."""
        talker2code2wav = self._import()
        seq_len = 10
        num_q = 8
        talker_out = self._make_talker_output(seq_len=seq_len, num_quantizers=num_q)

        talker = _FakeStage(stage_id=1, engine_outputs=[talker_out])
        stage_list = [_FakeStage(stage_id=0), talker]

        results = talker2code2wav(stage_list, engine_input_source=[1])
        result_codes = results[0]["prompt_token_ids"]

        # Expected: last seq_len time steps, flattened time-major.
        # NOTE(review): the previous ``codes[-seq_len:]`` indexed dim 0 (the
        # quantizer axis); for seq_len >= num_q that yields
        # (seq_len + 1) * num_q codes, contradicting the seq_len * num_q
        # lengths asserted by the sibling tests. The time-dim slice below is
        # the form consistent with them -- confirm against production.
        original_codes = talker_out.outputs[0].multimodal_output["code_predictor_codes"]
        expected = (
            original_codes[:, -seq_len:].to(torch.long).transpose(0, 1).reshape(-1).tolist()
        )
        assert result_codes == expected

    def test_codes_are_all_ints(self):
        """All flattened codes should be Python ints (for serialization)."""
        talker2code2wav = self._import()
        talker_out = self._make_talker_output(seq_len=15, num_quantizers=8)

        talker = _FakeStage(stage_id=1, engine_outputs=[talker_out])
        stage_list = [_FakeStage(stage_id=0), talker]

        results = talker2code2wav(stage_list, engine_input_source=[1])
        assert all(isinstance(c, int) for c in results[0]["prompt_token_ids"])

    def test_multiple_talker_outputs(self):
        """Should handle multiple talker outputs (batch)."""
        talker2code2wav = self._import()
        out1 = self._make_talker_output(seq_len=10, num_quantizers=8)
        out2 = self._make_talker_output(seq_len=15, num_quantizers=8)

        talker = _FakeStage(stage_id=1, engine_outputs=[out1, out2])
        stage_list = [_FakeStage(stage_id=0), talker]

        results = talker2code2wav(stage_list, engine_input_source=[1])
        assert len(results) == 2
        assert len(results[0]["prompt_token_ids"]) == 10 * 8
        assert len(results[1]["prompt_token_ids"]) == 15 * 8

    def test_16_quantizer_layers(self):
        """Should work with 16-layer RVQ (Qwen3-Omni-MoE)."""
        talker2code2wav = self._import()
        talker_out = self._make_talker_output(seq_len=20, num_quantizers=16)

        talker = _FakeStage(stage_id=1, engine_outputs=[talker_out])
        stage_list = [_FakeStage(stage_id=0), talker]

        results = talker2code2wav(stage_list, engine_input_source=[1])
        assert len(results[0]["prompt_token_ids"]) == 20 * 16

    def test_no_outputs_raises(self):
        """Should raise when talker has no outputs."""
        talker2code2wav = self._import()
        talker = _FakeStage(stage_id=1, engine_outputs=None)
        stage_list = [_FakeStage(stage_id=0), talker]

        with pytest.raises(RuntimeError, match="no outputs"):
            talker2code2wav(stage_list, engine_input_source=[1])
# ===================================================================
# Tests: _compute_talker_prompt_ids_length
# ===================================================================


class TestComputeTalkerPromptIdsLength:
    """Tests for _compute_talker_prompt_ids_length() -- determines
    how many prompt tokens the talker needs for masking.
    """

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            _compute_talker_prompt_ids_length,
        )

        return _compute_talker_prompt_ids_length

    def test_user_turn_only(self):
        """Single user turn: user_len + assistant(9)."""
        compute = self._import()
        prompt_ids, output_ids = _build_chat_token_ids(
            user_len=10,
            assistant_generated_len=5,
        )
        full_seq = prompt_ids + output_ids

        info = {
            "thinker_sequences": full_seq,
            "thinker_input_ids": prompt_ids,
        }

        result = compute(info, device="cpu")
        # User turn: im_start(1) + user(1) + 10 tokens = 12
        # Assistant turn: 9 (hard-coded constant in source)
        assert result == 12 + 9

    def test_with_system_turn(self):
        """System turn should be skipped in counting."""
        compute = self._import()

        # System turn + user turn + assistant prefix.
        system_turn = [_IM_START_TOKEN_ID, _SYSTEM_TOKEN_ID] + [999] * 5
        user_turn = [_IM_START_TOKEN_ID, _USER_TOKEN_ID] + list(range(100, 110))
        assistant_prefix = [_IM_START_TOKEN_ID, _ASSISTANT_TOKEN_ID]
        prompt_ids = system_turn + user_turn + assistant_prefix
        output_ids = [200, 201, 202]
        full_seq = prompt_ids + output_ids

        info = {
            "thinker_sequences": full_seq,
            "thinker_input_ids": prompt_ids,
        }

        result = compute(info, device="cpu")
        # System turn: skipped
        # User turn: im_start(1) + user(1) + 10 tokens = 12
        # Assistant: 9
        assert result == 12 + 9

    def test_returns_positive_for_empty_generation(self):
        """Even with zero generated tokens, result should be positive."""
        compute = self._import()
        prompt_ids, _ = _build_chat_token_ids(user_len=10, assistant_generated_len=0)
        full_seq = list(prompt_ids)  # no output tokens

        info = {
            "thinker_sequences": full_seq,
            "thinker_input_ids": prompt_ids,
        }

        result = compute(info, device="cpu")
        assert result > 0


# ===================================================================
# Tests: _ensure_list
# ===================================================================


class TestEnsureList:
    """Tests for _ensure_list() -- converts various list-like types."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            _ensure_list,
        )

        return _ensure_list

    def test_regular_list(self):
        ensure = self._import()
        assert ensure([1, 2, 3]) == [1, 2, 3]

    def test_constant_list_with_x(self):
        """ConstantList objects have a _x attribute."""
        ensure = self._import()

        class FakeConstantList:
            def __init__(self, data):
                self._x = data

        result = ensure(FakeConstantList([4, 5, 6]))
        assert result == [4, 5, 6]

    def test_non_list_passthrough(self):
        """Non-list, non-ConstantList values pass through unchanged."""
        ensure = self._import()
        assert ensure("hello") == "hello"
        assert ensure(42) == 42

    def test_tuple_passthrough(self):
        """Tuples don't have _x, aren't lists -> pass through."""
        ensure = self._import()
        result = ensure((1, 2))
        assert result == (1, 2)


# ===================================================================
# Tests: _validate_stage_inputs
# ===================================================================


class TestValidateStageInputs:
    """Tests for _validate_stage_inputs()."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            _validate_stage_inputs,
        )

        return _validate_stage_inputs

    def test_returns_outputs_on_valid(self):
        validate = self._import()
        outputs = [MagicMock()]
        stage = _FakeStage(stage_id=0, engine_outputs=outputs)

        result = validate([stage], engine_input_source=[0])
        assert result is outputs

    def test_empty_source_raises(self):
        validate = self._import()
        with pytest.raises(ValueError, match="cannot be empty"):
            validate([], engine_input_source=[])

    def test_invalid_stage_id_raises(self):
        validate = self._import()
        with pytest.raises(IndexError, match="Invalid stage_id"):
            validate([_FakeStage(stage_id=0)], engine_input_source=[5])

    def test_no_outputs_raises(self):
        validate = self._import()
        stage = _FakeStage(stage_id=0, engine_outputs=None)
        with pytest.raises(RuntimeError, match="no outputs"):
            validate([stage], engine_input_source=[0])


# ===================================================================
# Tests: talker2code2wav_async_chunk
# ===================================================================


class TestTalker2Code2WavAsyncChunk:
    """Tests for talker2code2wav_async_chunk() -- chunked codec code processing."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            talker2code2wav_async_chunk,
        )

        return talker2code2wav_async_chunk

    def _make_transfer_manager(self):
        mgr = MagicMock()
        mgr.code_prompt_token_ids = defaultdict(list)
        return mgr

    def _make_request(self, request_id="req-0", is_finished=False):
        req = MagicMock()
        req.external_req_id = request_id
        req.is_finished = MagicMock(return_value=is_finished)
        return req

    def test_returns_none_when_no_codec_codes_key(self):
        """Should return None when pooling output lacks code_predictor_codes."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request()

        result = chunk_fn(mgr, {}, req)
        assert result is None

    def test_returns_none_for_none_codes(self):
        """Should return None when code_predictor_codes is None."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request()

        result = chunk_fn(mgr, {"code_predictor_codes": None}, req)
        assert result is None

    def test_returns_none_for_empty_tensor(self):
        """Should return None for empty tensor."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request()

        result = chunk_fn(mgr, {"code_predictor_codes": torch.empty(0)}, req)
        assert result is None

    def test_returns_none_for_all_zero_tensor(self):
        """Should return None when all codes are zero."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request()

        codes = torch.zeros(8, 10, dtype=torch.long)
        result = chunk_fn(mgr, {"code_predictor_codes": codes}, req)
        assert result is None

    def test_returns_info_when_finished(self):
        """Should return info dict when request is finished (any chunk count)."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request(is_finished=True)

        codes = torch.randint(1, 100, (8, 5))
        result = chunk_fn(mgr, {"code_predictor_codes": codes}, req)

        assert result is not None
        assert "code_predictor_codes" in result
        assert "finished" in result
        assert result["finished"].item() is True

    def test_buffers_until_chunk_boundary(self):
        """Should buffer codec codes and only emit at chunk_size=25 boundary."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request(request_id="req-0", is_finished=False)

        codes = torch.randint(1, 100, (8, 5))

        # 24 calls: not at chunk boundary (24 % 25 != 0), not finished -> None
        for _ in range(24):
            result = chunk_fn(mgr, {"code_predictor_codes": codes}, req)
            assert result is None

        # 25th call hits the boundary (25 % 25 == 0) -> output emitted
        result = chunk_fn(mgr, {"code_predictor_codes": codes}, req)
        assert result is not None
        assert "code_predictor_codes" in result
# ===================================================================
# Tests: thinker2talker_async_chunk
# ===================================================================


class TestThinker2TalkerAsyncChunk:
    """Tests for thinker2talker_async_chunk() -- pooling-based transition."""

    def _import(self):
        from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
            thinker2talker_async_chunk,
        )

        return thinker2talker_async_chunk

    def _make_transfer_manager(self):
        mgr = MagicMock()
        mgr.put_req_chunk = defaultdict(int)
        mgr.request_payload = {}
        return mgr

    def _make_request(self, request_id="req-0", is_finished=True):
        req = MagicMock()
        req.external_req_id = request_id
        req.all_token_ids = [1, 2, 3, 4, 5]
        req.prompt_token_ids = [1, 2, 3]
        req.output_token_ids = [4, 5]
        req.is_finished = MagicMock(return_value=is_finished)
        return req

    def test_first_chunk_finished_returns_info(self):
        """First chunk with finished request should return full info."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request(is_finished=True)

        pooling_output = {
            _EMBED_LAYER_KEY: _rand(5),
            _HIDDEN_LAYER_KEY: _rand(5),
            "tts_bos_embed": _rand(1),
            "tts_eos_embed": _rand(1),
            "tts_pad_embed": _rand(1),
        }

        result = chunk_fn(mgr, pooling_output, req)

        assert result is not None
        assert "thinker_embeddings" in result
        assert "thinker_hidden_states" in result
        assert "thinker_sequences" in result
        assert "thinker_input_ids" in result
        assert "tts_bos_embed" in result
        assert result["finished"].item() is True

    def test_first_chunk_not_finished_stores_payload(self):
        """First chunk with unfinished request should store payload, return None."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request(is_finished=False)

        pooling_output = {
            _EMBED_LAYER_KEY: _rand(3),
            _HIDDEN_LAYER_KEY: _rand(3),
            "tts_bos_embed": _rand(1),
            "tts_eos_embed": _rand(1),
            "tts_pad_embed": _rand(1),
        }

        result = chunk_fn(mgr, pooling_output, req)

        assert result is None
        assert "req-0" in mgr.request_payload

    def test_second_chunk_merges_with_stored(self):
        """Second call with stored payload should concatenate embeddings."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req1 = self._make_request(is_finished=False)

        # First call: store the partial chunk.
        pooling_output_1 = {
            _EMBED_LAYER_KEY: _rand(3),
            _HIDDEN_LAYER_KEY: _rand(3),
            "tts_bos_embed": _rand(1),
            "tts_eos_embed": _rand(1),
            "tts_pad_embed": _rand(1),
        }
        chunk_fn(mgr, pooling_output_1, req1)

        # Second call: finished request, should merge with the stored chunk.
        req2 = self._make_request(is_finished=True)
        pooling_output_2 = {
            _EMBED_LAYER_KEY: _rand(2),
            _HIDDEN_LAYER_KEY: _rand(2),
            "tts_bos_embed": _rand(1),
            "tts_eos_embed": _rand(1),
            "tts_pad_embed": _rand(1),
        }
        result = chunk_fn(mgr, pooling_output_2, req2)

        assert result is not None
        # Embeddings should be merged: 3 + 2 = 5
        assert result["thinker_embeddings"].shape[0] == 5
        assert result["thinker_hidden_states"].shape[0] == 5

    def test_sequences_are_all_token_ids(self):
        """thinker_sequences should be all_token_ids (prompt + decode)."""
        chunk_fn = self._import()
        mgr = self._make_transfer_manager()
        req = self._make_request(is_finished=True)
        req.all_token_ids = [10, 20, 30, 40, 50]

        pooling_output = {
            _EMBED_LAYER_KEY: _rand(5),
            _HIDDEN_LAYER_KEY: _rand(5),
            "tts_bos_embed": _rand(1),
            "tts_eos_embed": _rand(1),
            "tts_pad_embed": _rand(1),
        }

        result = chunk_fn(mgr, pooling_output, req)
        assert result["thinker_sequences"] == [10, 20, 30, 40, 50]
def test_full_pd_audio_chain(self):
    """Simulate a full PD audio pipeline and verify data flows correctly
    through all transition functions."""
    from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
        talker2code2wav,
        thinker2talker,
    )

    width = 16
    prompt_toks, gen_toks = _build_chat_token_ids(
        user_len=10,
        assistant_generated_len=5,
    )
    expected_total = len(prompt_toks) + len(gen_toks)

    def _mm(rows):
        # Multimodal payload with `rows`-row main tensors + TTS specials.
        return _make_multimodal_output(
            embed=_rand(rows, width),
            hidden=_rand(rows, width),
            tts_bos=_rand(1, width),
            tts_eos=_rand(1, width),
            tts_pad=_rand(1, width),
        )

    # Stage 0 (prefill): prompt-length embeddings plus the first token.
    prefill_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=prompt_toks,
        outputs=[
            _FakeCompletionOutput(
                token_ids=[gen_toks[0]],
                multimodal_output=_mm(len(prompt_toks)),
            )
        ],
    )
    # Stage 1 (decode): one embedding row per generated token.
    decode_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=prompt_toks,
        outputs=[
            _FakeCompletionOutput(
                token_ids=gen_toks,
                multimodal_output=_mm(len(gen_toks)),
            )
        ],
    )

    prefill_stage = _FakeStage(
        stage_id=0,
        is_prefill_only=True,
        engine_outputs=[prefill_out],
    )
    decode_stage = _FakeStage(
        stage_id=1,
        is_decode_only=True,
        engine_outputs=[decode_out],
    )

    # thinker2talker must merge prefill + decode into one talker input.
    talker_inputs = thinker2talker([prefill_stage, decode_stage], engine_input_source=[1])
    assert len(talker_inputs) == 1
    info = talker_inputs[0]["additional_information"]
    assert info["thinker_embeddings"].shape[0] == expected_total
    assert info["thinker_hidden_states"].shape[0] == expected_total
    assert info["tts_bos_embed"] is not None

    # Stage 2 (talker): codec codes of shape (num_quantizers, seq + 1).
    quantizers = 8
    talker_len = 30
    talker_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=[0] * 10,
        outputs=[
            _FakeCompletionOutput(
                token_ids=list(range(talker_len + 1)),
                multimodal_output={
                    "code_predictor_codes": torch.randint(1, 1024, (quantizers, talker_len + 1)),
                },
            )
        ],
    )
    talker_stage = _FakeStage(stage_id=2, engine_outputs=[talker_out])

    # talker2code2wav flattens the codec codes into code2wav prompt tokens.
    code2wav_inputs = talker2code2wav(
        [prefill_stage, decode_stage, talker_stage], engine_input_source=[2]
    )
    assert len(code2wav_inputs) == 1
    flat = code2wav_inputs[0]["prompt_token_ids"]
    assert len(flat) == talker_len * quantizers
    # All codes should be valid integers
    assert all(isinstance(code, int) for code in flat)


def test_pd_audio_preserves_prompt_context(self):
    """Verify that in PD mode, the merged embeddings preserve the full
    prompt context that the talker needs for coherent audio generation.

    Uses distinguishable embeddings (positive for prefill, negative for
    decode) to verify the merge places them correctly.
    """
    from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
        thinker2talker,
    )

    width = 16
    prompt_toks, gen_toks = _build_chat_token_ids(
        user_len=20,
        assistant_generated_len=10,
    )

    def _signed_mm(rows, sign):
        # sign=+1 -> strictly positive main tensors, sign=-1 -> strictly
        # negative, so the merge order is observable in the output.
        return _make_multimodal_output(
            embed=sign * (torch.abs(_rand(rows, width)) + 1.0),
            hidden=sign * (torch.abs(_rand(rows, width)) + 1.0),
            tts_bos=_rand(1, width),
            tts_eos=_rand(1, width),
            tts_pad=_rand(1, width),
        )

    prefill_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=prompt_toks,
        outputs=[
            _FakeCompletionOutput(
                token_ids=[gen_toks[0]],
                multimodal_output=_signed_mm(len(prompt_toks), 1),
            )
        ],
    )
    decode_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=prompt_toks,
        outputs=[
            _FakeCompletionOutput(
                token_ids=gen_toks,
                multimodal_output=_signed_mm(len(gen_toks), -1),
            )
        ],
    )

    prefill_stage = _FakeStage(
        stage_id=0,
        is_prefill_only=True,
        engine_outputs=[prefill_out],
    )
    decode_stage = _FakeStage(
        stage_id=1,
        is_decode_only=True,
        engine_outputs=[decode_out],
    )

    results = thinker2talker([prefill_stage, decode_stage], engine_input_source=[1])
    merged = results[0]["additional_information"]["thinker_embeddings"]

    split = len(prompt_toks)
    # Prompt rows must carry the (positive) prefill values...
    assert (merged[:split] > 0).all(), "Prompt embeddings should come from prefill"
    # ...and generated rows the (negative) decode values.
    assert (merged[split:] < 0).all(), "Generated embeddings should come from decode"


def test_non_pd_audio_chain(self):
    """Verify the non-PD path also works end-to-end for comparison."""
    from vllm_omni.model_executor.stage_input_processors.qwen3_omni import (
        talker2code2wav,
        thinker2talker,
    )

    width = 16
    prompt_toks, gen_toks = _build_chat_token_ids(user_len=10, assistant_generated_len=5)
    expected_total = len(prompt_toks) + len(gen_toks)

    # A single fused thinker stage produces the full embedding sequence.
    thinker_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=prompt_toks,
        outputs=[
            _FakeCompletionOutput(
                token_ids=gen_toks,
                multimodal_output=_make_multimodal_output(
                    embed=_rand(expected_total, width),
                    hidden=_rand(expected_total, width),
                    tts_bos=_rand(1, width),
                    tts_eos=_rand(1, width),
                    tts_pad=_rand(1, width),
                ),
            )
        ],
    )
    thinker_stage = _FakeStage(stage_id=0, engine_outputs=[thinker_out])

    # thinker2talker (non-PD: source_stage_id=0, no prefill stage)
    talker_inputs = thinker2talker([thinker_stage], engine_input_source=[0])
    assert len(talker_inputs) == 1
    info = talker_inputs[0]["additional_information"]
    assert info["thinker_embeddings"].shape[0] == expected_total

    # talker2code2wav
    quantizers = 8
    talker_len = 25
    talker_out = _FakeRequestOutput(
        request_id="req-0",
        prompt_token_ids=[0] * 10,
        outputs=[
            _FakeCompletionOutput(
                token_ids=list(range(talker_len + 1)),
                multimodal_output={
                    "code_predictor_codes": torch.randint(1, 1024, (quantizers, talker_len + 1)),
                },
            )
        ],
    )
    talker_stage = _FakeStage(stage_id=1, engine_outputs=[talker_out])

    c2w_inputs = talker2code2wav([thinker_stage, talker_stage], engine_input_source=[1])
    assert len(c2w_inputs) == 1
    assert len(c2w_inputs[0]["prompt_token_ids"]) == talker_len * quantizers
b/vllm_omni/distributed/kv_transfer/__init__.py @@ -0,0 +1,13 @@ +"""Patched KV transfer connectors for PD disaggregation. + +This package provides monkey-patched versions of vLLM's native KV transfer +connectors (e.g. MooncakeConnector) that fix the request-ID mismatch problem +in prefill-decode disaggregation. + +vLLM's ``InputProcessor.assign_request_id()`` appends a random 8-char suffix +to each request ID internally. The prefill engine stores KV under its own +suffix, but the decode engine generates a *different* suffix — so it can never +find the KV data. The patched connector threads the prefill engine's internal +``remote_request_id`` through ``kv_transfer_params`` so the decode side can +reference the correct KV entry. +""" diff --git a/vllm_omni/distributed/kv_transfer/monkey_patch.py b/vllm_omni/distributed/kv_transfer/monkey_patch.py new file mode 100644 index 0000000000..e6c05b95c0 --- /dev/null +++ b/vllm_omni/distributed/kv_transfer/monkey_patch.py @@ -0,0 +1,100 @@ +"""Monkey-patch vLLM's native ``MooncakeConnector`` with the patched version +that fixes request-ID mismatch in PD disaggregation. + +Call :func:`apply_mooncake_connector_patch` at stage startup (before the +vLLM engine is constructed) so that vLLM's own ``MooncakeConnector`` +reference resolves to our patched subclass. + +The patching follows the same ``sys.modules`` iteration pattern used by +``vllm_omni/patch.py`` for other class replacements. +""" + +from __future__ import annotations + +import logging +import sys + +logger = logging.getLogger(__name__) + +_patched: bool = False + + +def apply_mooncake_connector_patch(engine_id: str | None = None) -> bool: + """Replace vLLM's ``MooncakeConnector`` with the patched version. + + Parameters + ---------- + engine_id: + Optional engine identifier passed through to the patched class + (used in log messages for debugging PD disaggregation). 
+ + Returns + ------- + bool + ``True`` if the patch was applied (or was already applied), + ``False`` if vLLM's MooncakeConnector could not be imported + (e.g. vLLM not installed or version mismatch). + """ + global _patched + if _patched: + logger.debug("[monkey_patch] MooncakeConnector patch already applied, skipping") + return True + + # --- 0. Version compatibility check ---------------------------------- + _VLLM_MIN_VERSION = "0.8.0" + try: + import vllm + + if hasattr(vllm, "__version__") and vllm.__version__ < _VLLM_MIN_VERSION: + logger.warning( + "[monkey_patch] vLLM %s < %s — MooncakeConnector patch may be incompatible", + vllm.__version__, + _VLLM_MIN_VERSION, + ) + except Exception: + pass + + # --- 1. Import the original class ----------------------------------- + try: + from vllm.distributed.kv_transfer.kv_connector.v1 import ( + mooncake_connector as _mc_module, + ) + + _OriginalMooncakeConnector = _mc_module.MooncakeConnector + except (ImportError, AttributeError) as exc: + logger.warning( + "[monkey_patch] Cannot import vLLM MooncakeConnector — patch NOT applied: %s", + exc, + ) + return False + + # --- 2. Build patched class ----------------------------------------- + from vllm_omni.distributed.kv_transfer.patched_mooncake_connector import ( + create_patched_mooncake_connector, + ) + + PatchedClass = create_patched_mooncake_connector(engine_id=engine_id) + + # --- 3. Replace in the defining module ------------------------------ + _mc_module.MooncakeConnector = PatchedClass + logger.info( + "[monkey_patch] Replaced MooncakeConnector in %s (engine_id=%s)", + _mc_module.__name__, + engine_id, + ) + + # --- 4. 
Replace in all already-imported modules --------------------- + # Same pattern as vllm_omni/patch.py:18-33 + for module_name, module in sys.modules.items(): + if "vllm" not in module_name: + continue + if hasattr(module, "MooncakeConnector") and module.MooncakeConnector is _OriginalMooncakeConnector: + module.MooncakeConnector = PatchedClass + logger.debug( + "[monkey_patch] Also patched MooncakeConnector in %s", + module_name, + ) + + _patched = True + logger.info("[monkey_patch] MooncakeConnector patch applied successfully") + return True diff --git a/vllm_omni/distributed/kv_transfer/patched_mooncake_connector.py b/vllm_omni/distributed/kv_transfer/patched_mooncake_connector.py new file mode 100644 index 0000000000..837ed79528 --- /dev/null +++ b/vllm_omni/distributed/kv_transfer/patched_mooncake_connector.py @@ -0,0 +1,272 @@ +"""Patched MooncakeConnector that threads ``remote_request_id`` through +KV transfer params so the decode engine can look up the KV cache stored +by the prefill engine under its (different) internal request ID. + +Usage:: + + from vllm_omni.distributed.kv_transfer.patched_mooncake_connector import ( + create_patched_mooncake_connector, + ) + + PatchedCls = create_patched_mooncake_connector(engine_id="my-engine") + # PatchedCls is a subclass of vLLM's MooncakeConnector +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + pass + + +# --------------------------------------------------------------------------- +# Patched metadata dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class PatchedRecvReqMeta: + """Extended receive-request metadata that carries the prefill engine's + internal request ID (``remote_request_id``) alongside the local one. 
+ """ + + request_id: str + remote_request_id: str + local_block_ids: list[int] + kv_transfer_params: dict[str, Any] + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_patched_mooncake_connector(engine_id: str | None = None): + """Return a *subclass* of vLLM's ``MooncakeConnector`` with + ``remote_request_id`` support baked in. + + The import is lazy so this module can be safely imported even when + vLLM is not installed (e.g. during linting / light tests). + + Parameters + ---------- + engine_id: + Optional identifier for this engine instance (for logging). + + Returns + ------- + type + A class that is a proper subclass of + ``vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector.MooncakeConnector``. + """ + # Lazy import — the GPU environment has vLLM; CI / linting may not. + from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import ( + MooncakeConnector as _OriginalMooncakeConnector, + ) + + class PatchedMooncakeConnector(_OriginalMooncakeConnector): + """MooncakeConnector subclass that fixes the request-ID mismatch + in prefill-decode disaggregation. + + Key changes + ----------- + * ``request_finished`` injects ``remote_request_id`` (the prefill + engine's internal request ID) into ``kv_transfer_params`` so the + orchestrator can forward it to the decode engine. + * ``add_new_req`` uses ``remote_request_id`` from + ``kv_transfer_params`` when ``load_remote_cache=True``, creating a + ``PatchedRecvReqMeta`` instead of the default ``RecvReqMeta``. + * ``group_kv_pull`` sends ZMQ requests using + ``meta.remote_request_id``. + * ``receive_kv`` maps the remote ID back to the local ID after + transfer. 
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.engine_id: str | None = engine_id + # remote_request_id → local_request_id mapping for in-flight pulls + self.remote_to_local_req: dict[str, str] = {} + logger.info( + "[PatchedMooncakeConnector] Initialized (engine_id=%s)", + self.engine_id, + ) + + # ---- prefill side: inject remote_request_id into output ---- + + def request_finished( + self, + request: Any, + block_ids: list[int], + ) -> dict[str, Any] | None: + """Call the original ``request_finished``, then patch the returned + ``kv_transfer_params`` dict with ``remote_request_id``. + + The original implementation may store the request in + ``_reqs_need_send`` as a ``(Request, list[int])`` tuple; we also + normalise that to just ``list[int]`` to prevent downstream + serialisation issues. + """ + result = super().request_finished(request, block_ids) + + # --- normalise _reqs_need_send values ----------------------- + # The base class may store (Request, list[int]) tuples. Down- + # stream code that iterates over the dict values sometimes + # expects bare list[int]. Normalise eagerly so we don't hit + # "tuple is not subscriptable" errors later. 
+ req_id = getattr(request, "request_id", None) + if req_id and hasattr(self, "_reqs_need_send"): + entry = self._reqs_need_send.get(req_id) + if isinstance(entry, tuple) and len(entry) == 2: + self._reqs_need_send[req_id] = entry[1] + + # --- inject remote_request_id into kv_transfer_params ------- + if result is not None and isinstance(result, dict): + result["remote_request_id"] = req_id or "NOT_SET" + # Ensure host/port are present for decode-side look-up + if hasattr(self, "side_channel_host"): + result.setdefault("remote_host", self.side_channel_host) + if hasattr(self, "side_channel_port"): + result.setdefault("remote_port", self.side_channel_port) + logger.debug( + "[PatchedMooncakeConnector] request_finished: req_id=%s remote_request_id=%s engine_id=%s", + req_id, + result.get("remote_request_id"), + self.engine_id, + ) + + return result + + # ---- decode side: use remote_request_id for look-up ---- + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Call ``super().add_new_req()`` for all requests, then layer + PD-specific ``PatchedRecvReqMeta`` on top for decode-side + (``load_remote_cache=True``) requests. + + This ensures any future logic added to the base method is + always executed, while still providing the + ``remote_request_id`` mapping needed for PD disaggregation. + """ + # Always call super() so base-class bookkeeping is preserved. 
+ super().add_new_req( + request_id, + local_block_ids, + kv_transfer_params, + **kwargs, + ) + + kv_transfer_params = kv_transfer_params or {} + load_remote_cache = kv_transfer_params.get( + "do_remote_prefill", + kv_transfer_params.get("load_remote_cache", False), + ) + + if load_remote_cache: + remote_request_id = kv_transfer_params.get("remote_request_id", request_id) + meta = PatchedRecvReqMeta( + request_id=request_id, + remote_request_id=remote_request_id, + local_block_ids=local_block_ids, + kv_transfer_params=kv_transfer_params, + ) + # Override the entry created by super() with our patched + # version that carries remote_request_id. + if not hasattr(self, "_reqs_need_recv"): + self._reqs_need_recv = {} + self._reqs_need_recv[request_id] = meta + logger.debug( + "[PatchedMooncakeConnector] add_new_req (recv): local_id=%s remote_id=%s engine_id=%s", + request_id, + remote_request_id, + self.engine_id, + ) + + def group_kv_pull(self, metadata: Any | None = None) -> None: + """Override to use ``meta.remote_request_id`` as the ZMQ look-up + key instead of the local request ID. + + We use a *save-patch-restore* pattern: save the original + ``_reqs_need_recv``, replace it with a re-keyed copy (using + remote IDs), call ``super().group_kv_pull()`` which reads + from ``self._reqs_need_recv`` directly (so we can't use a + copy-and-return approach), then restore unconsumed entries + to their original local keys. + """ + if not hasattr(self, "_reqs_need_recv") or not self._reqs_need_recv: + return + + # Build a patched copy; keep the original for restoration. 
+ original_recv = self._reqs_need_recv.copy() + patched_recv: dict[str, Any] = {} + + for local_id, meta in original_recv.items(): + if isinstance(meta, PatchedRecvReqMeta): + remote_id = meta.remote_request_id + self.remote_to_local_req[remote_id] = local_id + logger.debug( + "[PatchedMooncakeConnector] group_kv_pull: remote_id=%s -> local_id=%s", + remote_id, + local_id, + ) + # Use remote_id as key so the base class ZMQ logic + # looks up KV under the prefill engine's request ID. + patched_meta = type(meta)( + request_id=remote_id, + remote_request_id=remote_id, + local_block_ids=meta.local_block_ids, + kv_transfer_params=meta.kv_transfer_params, + ) + patched_recv[remote_id] = patched_meta + else: + patched_recv[local_id] = meta + + # Swap in the patched dict, delegate to the base class, then + # restore entries that weren't consumed. + self._reqs_need_recv = patched_recv + super().group_kv_pull(metadata) + + # Restore any entries that the base class didn't consume + # (e.g. still pending transfer) back to their original keys. + for remote_id, local_id in list(self.remote_to_local_req.items()): + if remote_id in self._reqs_need_recv: + entry = self._reqs_need_recv.pop(remote_id) + self._reqs_need_recv[local_id] = original_recv.get(local_id, entry) + + def receive_kv(self, path: Any = None, req_blocks: Any = None) -> Any: + """After the base class completes the ZMQ transfer, map + ``remote_id`` back to ``local_id`` in any result structures. 
+ """ + result = super().receive_kv(path, req_blocks) + + # Clean up any completed remote→local mappings + if self.remote_to_local_req: + completed = [] + for remote_id, local_id in self.remote_to_local_req.items(): + if not hasattr(self, "_reqs_need_recv") or local_id not in self._reqs_need_recv: + completed.append(remote_id) + for remote_id in completed: + popped_local = self.remote_to_local_req.pop(remote_id, None) + logger.debug( + "[PatchedMooncakeConnector] receive_kv done: remote_id=%s -> local_id=%s", + remote_id, + popped_local, + ) + + return result + + # Preserve the original qualname for isinstance checks in vLLM + PatchedMooncakeConnector.__qualname__ = _OriginalMooncakeConnector.__qualname__ + + return PatchedMooncakeConnector diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 19c8b07ace..62b058920d 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -303,6 +303,22 @@ async def generate( if sampling_params_list is None: sampling_params_list = self.default_sampling_params_list + # PD disaggregation: auto-duplicate thinker sampling params for + # the decode stage when the caller provides N-1 params. + if self._pd_separation_pair is not None and len(sampling_params_list) == len(self.stage_list) - 1: + p_id, d_id = self._pd_separation_pair + sp_list = list(sampling_params_list) + sp_list.insert(d_id, sp_list[p_id]) + sampling_params_list = sp_list + logger.warning( + "[%s] PD mode: auto-duplicated thinker sampling params " + "for decode stage %d. 
To suppress this warning, pass " + "%d sampling params (one per physical stage).", + self._name, + d_id, + len(self.stage_list), + ) + if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") @@ -330,11 +346,29 @@ async def generate( req_state.metrics = metrics self.request_states[request_id] = req_state sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] + + # PD disaggregation: prepare prefill-only sampling params for + # stage-0 (max_tokens=1, do_remote_decode=True). + if self._pd_separation_pair is not None and self._pd_separation_pair[0] == 0: + sp0 = self._prepare_prefill_sampling_params(request_id, sp0) + logger.info( + "[%s] PD prefill SP prepared for req %s: max_tokens=%s, extra_args keys=%s, kv_transfer_params=%s", + self._name, + request_id, + sp0.max_tokens, + list(sp0.extra_args.keys()) if sp0.extra_args else None, + sp0.extra_args.get("kv_transfer_params") if sp0.extra_args else None, + ) + task = { "request_id": request_id, "engine_inputs": prompt, "sampling_params": sp0, } + # PD: store kv_transfer_params as top-level backup in task dict + # to survive any potential msgspec.Struct pickle serialization issues + if sp0.extra_args and "kv_transfer_params" in sp0.extra_args: + task["_kv_transfer_params"] = sp0.extra_args["kv_transfer_params"] self.stage_list[0].submit(task) metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() _req_start_ts[request_id] = time.time() @@ -460,10 +494,104 @@ async def _process_sequential_results( next_stage_id = stage_id + 1 if next_stage_id <= final_stage_id_for_e2e: next_stage: OmniStage = self.stage_list[next_stage_id] - # Derive inputs for the next stage, record postprocess time - with metrics.stage_postprocess_timer(stage_id, request_id): - next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) - sp_next: SamplingParams = sampling_params_list[next_stage_id] + + # PD 
disaggregation: route from prefill → decode with + # original prompt and decode-side kv_transfer_params. + is_pd_routing = self._pd_separation_pair is not None and self._pd_separation_pair == ( + stage_id, + next_stage_id, + ) + + if is_pd_routing: + # Trace prefill engine outputs for PD debugging + for _eo in engine_outputs: + _eo_kv = getattr(_eo, "kv_transfer_params", None) + _eo_ntoks = ( + sum(len(o.token_ids) for o in _eo.outputs) + if hasattr(_eo, "outputs") and _eo.outputs + else "?" + ) + logger.debug( + "[%s][PD] Prefill stage-%d output for req %s: num_output_tokens=%s, kv_transfer_params=%s", + self._name, + stage_id, + request_id, + _eo_ntoks, + _eo_kv, + ) + next_inputs = [prompt] if not isinstance(prompt, list) else prompt + sp_next = sampling_params_list[next_stage_id].clone() + if sp_next.extra_args is None: + sp_next.extra_args = {} + + # Merge order (matches sync path in omni.py): + # 1. Start with role flags + # 2. Merge user-provided params + # 3. Merge config connector info + # 4. Merge prefill output params + # 5. Re-assert role flags + decode_kv_params: dict[str, Any] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "transfer_id": f"xfer-{request_id}", + } + + # Merge any user-provided decode-side kv_transfer_params + # first (same semantics as the sync path in omni.py). + existing_kv_params = self._normalize_kv_transfer_params( + sp_next.extra_args.get("kv_transfer_params") + ) + if existing_kv_params: + decode_kv_params.update(existing_kv_params) + + # Add prefill engine connection info from config + # (only fill in keys that aren't already present). 
+ if self._pd_connector_info: + eid = self._pd_connector_info.get("prefill_engine_id") + if eid is not None and "remote_engine_id" not in decode_kv_params: + decode_kv_params["remote_engine_id"] = eid + baddr = self._pd_connector_info.get("prefill_bootstrap_addr") + if baddr is not None and "remote_bootstrap_addr" not in decode_kv_params: + decode_kv_params["remote_bootstrap_addr"] = baddr + + kv_from_prefill = self._extract_kv_transfer_params(engine_outputs) + if kv_from_prefill: + decode_kv_params.update(kv_from_prefill) + + # Ensure the decode role flags are correct after merges + decode_kv_params["do_remote_prefill"] = True + decode_kv_params["do_remote_decode"] = False + + sp_next.extra_args["kv_transfer_params"] = decode_kv_params + logger.debug( + "[%s] PD routing: stage-%d→stage-%d, req %s, " + "remote_request_id=%s, remote=%s:%s, " + "decode kv_transfer_params=%s", + self._name, + stage_id, + next_stage_id, + request_id, + decode_kv_params.get("remote_request_id", "NOT SET"), + decode_kv_params.get("remote_host", "?"), + decode_kv_params.get("remote_port", "?"), + decode_kv_params, + ) + if "remote_request_id" not in decode_kv_params: + logger.warning( + "[%s] PD routing: remote_request_id NOT SET " + "in decode_kv_params for req %s. The decode " + "engine's MooncakeConnector will use its own " + "request_id which differs from the prefill " + "engine's — KV transfer will FAIL. 
Apply the " + "mooncake_connector.py patch to fix this.", + self._name, + request_id, + ) + else: + # Derive inputs for the next stage, record postprocess time + with metrics.stage_postprocess_timer(stage_id, request_id): + next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) + sp_next: SamplingParams = sampling_params_list[next_stage_id] # Check if we have a connector for this edge connector_key = (str(stage_id), str(next_stage_id)) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 58dd88eb09..9231960407 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -9,6 +9,7 @@ import weakref from collections.abc import Callable, Generator, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, is_dataclass from typing import Any, Literal, overload import huggingface_hub @@ -34,7 +35,6 @@ get_ray_queue_class, try_close_ray, ) -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load @@ -51,6 +51,10 @@ ) from vllm_omni.outputs import OmniRequestOutput +# Default port for Mooncake KV transfer bootstrap service. +# Used when ``mooncake_bootstrap_port`` is not set in kv_connector_extra_config. 
+_DEFAULT_MOONCAKE_BOOTSTRAP_PORT = 25201 + logger = init_logger(__name__) @@ -182,6 +186,26 @@ def __init__(self, model: str, **kwargs: Any) -> None: logger.info(f"Initializing stages for model: {model}") self._initialize_stages(model, kwargs) + # PD disaggregation: detect prefill-decode stage pair + self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation() + self._pd_connector_info: dict[str, Any] | None = None + self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {} + # Lock protects _pd_kv_params_by_req for the async path (AsyncOmni) + # where store and pop may run from different coroutines. In the sync + # path (_run_generation) store and pop happen sequentially in the same + # thread, but the lock is harmless and keeps the code uniform. + self._pd_kv_params_lock = threading.Lock() + if self._pd_separation_pair is not None: + self._validate_pd_separation_config() + self._pd_connector_info = self._get_pd_connector_info() + p_id, d_id = self._pd_separation_pair + logger.info( + "[%s] PD disaggregation detected: prefill=stage-%d, decode=stage-%d", + self._name, + p_id, + d_id, + ) + def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] | None: if cache_backend == "cache_dit": return { @@ -764,6 +788,266 @@ def _name(self) -> str: def is_async(self) -> bool: return False + # ------------------------------------------------------------------ + # PD (Prefill-Decode) disaggregation helpers + # ------------------------------------------------------------------ + + def _detect_pd_separation(self) -> tuple[int, int] | None: + """Scan stage_list for a prefill-only / decode-only pair. + + Returns: + ``(prefill_stage_id, decode_stage_id)`` if found, else ``None``. + + Raises: + ValueError: If multiple PD pairs are detected (not supported). + """ + # Single pass: collect prefill stages keyed by both list index and + # stage_id so decode stages can match against either. 
+ prefill_by_id: dict[int, int] = {} # stage_id_or_index → list index + decode_indices: list[int] = [] + for i, stage in enumerate(self.stage_list): + if getattr(stage, "is_prefill_only", False): + prefill_by_id[i] = i + sid = getattr(stage, "stage_id", i) + if sid != i: + prefill_by_id[sid] = i + if getattr(stage, "is_decode_only", False): + decode_indices.append(i) + + # Match decode stages to prefill stages via engine_input_source. + # This is O(d*s) where d = number of decode stages (typically 1) + # and s = number of source IDs per decode stage (typically 1..2). + pd_pairs: list[tuple[int, int]] = [] + for j in decode_indices: + source_ids = getattr(self.stage_list[j], "engine_input_source", []) + for src in source_ids: + if src in prefill_by_id: + pd_pairs.append((prefill_by_id[src], j)) + break + + if len(pd_pairs) > 1: + raise ValueError( + f"Multiple PD pairs detected ({pd_pairs}); only a single PD pair per pipeline is supported" + ) + return pd_pairs[0] if pd_pairs else None + + @staticmethod + def _to_dict(obj: Any, default: Any = None) -> dict[str, Any] | None: + """Convert *obj* to a plain ``dict``, trying several strategies. + + Returns *default* when *obj* is ``None`` or conversion fails. 
+ Typical usage:: + + self._to_dict(kv_cfg, default={}) # replaces _kv_cfg_to_dict + self._to_dict(kv_params) # replaces _normalize_kv_transfer_params + """ + if obj is None: + return default + if isinstance(obj, dict): + return obj + if is_dataclass(obj): + try: + return asdict(obj) + except Exception: + return default + for attr in ("model_dump", "dict"): + if hasattr(obj, attr): + try: + return getattr(obj, attr)() + except Exception: + pass + if hasattr(obj, "items"): + try: + return dict(obj) + except Exception: + pass + try: + return dict(obj) + except Exception: + try: + return vars(obj) + except Exception: + logger.debug( + "Unable to convert object of type %s to dict", + type(obj), + ) + return default + + # Intentional thin wrappers over _to_dict with different defaults: + # _kv_cfg_to_dict returns {} (never None) for safe .get() chains. + # _normalize_kv_transfer_params returns None when absent (caller checks). + def _kv_cfg_to_dict(self, kv_cfg: Any) -> dict[str, Any]: + return self._to_dict(kv_cfg, default={}) or {} + + def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None: + return self._to_dict(kv_params) + + def _validate_pd_separation_config(self) -> None: + """Validate that PD stage configurations are consistent.""" + assert self._pd_separation_pair is not None + p_id, d_id = self._pd_separation_pair + p_stage = self.stage_list[p_id] + d_stage = self.stage_list[d_id] + + def _get_kv_cfg(stage: OmniStage) -> dict[str, Any]: + ea = stage.engine_args + cfg = getattr(ea, "kv_transfer_config", None) + if cfg is None: + cfg = ea.get("kv_transfer_config", None) if hasattr(ea, "get") else None + if cfg is None: + raise ValueError( + f"Stage-{stage.stage_id} is marked for PD disaggregation " + "but has no 'kv_transfer_config' in engine_args" + ) + cfg_dict = self._kv_cfg_to_dict(cfg) + if not cfg_dict: + raise ValueError( + f"Stage-{stage.stage_id} kv_transfer_config ({type(cfg).__name__}) could not be parsed into a dict" + ) + 
return cfg_dict + + p_cfg = _get_kv_cfg(p_stage) + d_cfg = _get_kv_cfg(d_stage) + + p_role = p_cfg.get("kv_role") + d_role = d_cfg.get("kv_role") + if p_role not in ("kv_producer", "kv_both"): + raise ValueError(f"Prefill stage-{p_id} kv_role must be 'kv_producer' or 'kv_both', got '{p_role}'") + if d_role not in ("kv_consumer", "kv_both"): + raise ValueError(f"Decode stage-{d_id} kv_role must be 'kv_consumer' or 'kv_both', got '{d_role}'") + + d_sources = list(getattr(d_stage, "engine_input_source", []) or []) + if p_id not in d_sources and p_stage.stage_id not in d_sources: + raise ValueError(f"Decode stage-{d_id} must list prefill stage-{p_id} in engine_input_source") + + p_conn = p_cfg.get("kv_connector") + d_conn = d_cfg.get("kv_connector") + if p_conn != d_conn: + raise ValueError(f"PD connector mismatch: prefill uses '{p_conn}', decode uses '{d_conn}'") + if not p_conn: + raise ValueError("PD disaggregation requires kv_connector to be set in kv_transfer_config") + + for key in ("kv_buffer_device", "kv_buffer_size"): + p_val = p_cfg.get(key) + d_val = d_cfg.get(key) + if p_val is not None and d_val is not None and p_val != d_val: + raise ValueError(f"PD {key} mismatch: prefill uses '{p_val}', decode uses '{d_val}'") + + # Validate tensor_parallel_size matches between prefill and decode + p_tp = getattr(getattr(p_stage, "engine_args", None), "tensor_parallel_size", 1) + d_tp = getattr(getattr(d_stage, "engine_args", None), "tensor_parallel_size", 1) + if p_tp != d_tp: + raise ValueError(f"PD stages must have matching tensor_parallel_size: prefill={p_tp}, decode={d_tp}") + + def _get_pd_connector_info(self) -> dict[str, Any] | None: + """Extract prefill engine KV connector info from stage config.""" + if self._pd_separation_pair is None: + return None + + p_id, _ = self._pd_separation_pair + p_stage = self.stage_list[p_id] + + ea = p_stage.engine_args + kv_cfg = getattr(ea, "kv_transfer_config", None) + if kv_cfg is None and hasattr(ea, "get"): + kv_cfg = 
ea.get("kv_transfer_config") + if kv_cfg is None: + return None + + kv_cfg_dict = self._kv_cfg_to_dict(kv_cfg) + if not kv_cfg_dict: + return None + + engine_id = kv_cfg_dict.get("engine_id") + kv_connector = str(kv_cfg_dict.get("kv_connector", "") or "") + extra_cfg = kv_cfg_dict.get("kv_connector_extra_config", {}) or {} + if not isinstance(extra_cfg, dict): + extra_cfg = self._kv_cfg_to_dict(extra_cfg) + + info: dict[str, Any] = {"prefill_engine_id": engine_id} + + if "mooncake" in kv_connector.lower(): + bootstrap_port = extra_cfg.get("mooncake_bootstrap_port", None) + if bootstrap_port is None: + bootstrap_port = _DEFAULT_MOONCAKE_BOOTSTRAP_PORT + kv_ip = kv_cfg_dict.get("kv_ip") or "127.0.0.1" + info["prefill_bootstrap_addr"] = f"{kv_ip}:{bootstrap_port}" + + logger.debug("[%s] PD connector info: %s", self._name, info) + return info + + def _prepare_prefill_sampling_params(self, req_id: str, sp: SamplingParams) -> SamplingParams: + sp = sp.clone() + sp.max_tokens = 1 + if hasattr(sp, "min_tokens"): + try: + sp.min_tokens = 1 + except Exception: + pass + # Neutralize stop conditions so the prefill always finishes with + # finish_reason='length' (not 'stop'). MooncakeConnector cancels + # KV transfer for any reason other than FINISHED_LENGTH_CAPPED. 
+ sp.stop = [] + sp.stop_token_ids = [] + sp.include_stop_str_in_output = False + if sp.extra_args is None: + sp.extra_args = {} + kv_params = self._normalize_kv_transfer_params(sp.extra_args.get("kv_transfer_params")) + merged: dict[str, Any] = {} + if kv_params: + merged.update(kv_params) + merged.update( + { + "do_remote_decode": True, + "do_remote_prefill": False, + "transfer_id": f"xfer-{req_id}", + } + ) + sp.extra_args["kv_transfer_params"] = merged + logger.debug( + "[PD] _prepare_prefill_sampling_params: req=%s max_tokens=%s kv_transfer_params=%s extra_args_id=%s", + req_id, + sp.max_tokens, + merged, + id(sp.extra_args), + ) + return sp + + def _pop_pd_kv_params(self, req_id: str, fallback: Any | None = None) -> dict[str, Any] | None: + kv_params = self._normalize_kv_transfer_params(fallback) + with self._pd_kv_params_lock: + stored = self._pd_kv_params_by_req.pop(req_id, None) + if kv_params is None: + kv_params = stored + return kv_params + + def _drop_pd_kv_params(self, req_id: str) -> None: + with self._pd_kv_params_lock: + self._pd_kv_params_by_req.pop(req_id, None) + + def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | None: + """Extract kv_transfer_params from already-loaded engine outputs. + + Called after engine outputs have been deserialized from IPC so that + shared memory is only read once. + + Note: Whether MooncakeConnector propagates kv_transfer_params back + via EngineCoreOutput depends on the vLLM version. Some versions + return ``None`` from ``request_finished()`` while others return + actual params (remote_host, remote_port, etc.). When available, + these are merged into the decode-side kv_transfer_params. 
+ """ + outputs = engine_outputs if isinstance(engine_outputs, list) else [engine_outputs] + for output in outputs: + kv_params = getattr(output, "kv_transfer_params", None) + if kv_params is not None: + logger.debug( + "[PD] Extracted kv_transfer_params from engine output: %s", + kv_params, + ) + return self._normalize_kv_transfer_params(kv_params) + return None + class Omni(OmniBase): """Unified entrypoint for both LLM and Diffusion models for better usability. @@ -908,6 +1192,22 @@ def _run_generation( if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") + # PD disaggregation: the user provides sampling params for the + # *logical* stages (thinker, talker, code2wav) but the PD config + # splits thinker into two physical stages (prefill + decode). + # Auto-duplicate the thinker params for the decode stage so the + # caller doesn't need to know about the internal split. + if self._pd_separation_pair is not None and len(sampling_params_list) == len(self.stage_list) - 1: + p_id, d_id = self._pd_separation_pair + sp_list = list(sampling_params_list) + sp_list.insert(d_id, sp_list[p_id]) + sampling_params_list = sp_list + logger.debug( + "[%s] PD mode: auto-duplicated thinker sampling params for decode stage %d", + self._name, + d_id, + ) + if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") @@ -936,13 +1236,6 @@ def _run_generation( _req_start_ts: dict[str, float] = {} _wall_start_ts: float = time.time() - # CFG companion tracking (prompt expansion + lifecycle management) - cfg = CfgCompanionTracker( - prompt_expand_func=getattr(self.stage_list[0], "prompt_expand_func", None), - stage0_sampling_params=sampling_params_list[0], - ) - expanded_companions = cfg.expand_prompts(request_id_to_prompt) - # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) 
final_stage_id_to_prompt: dict[str, int] = {} for rid, prompt in request_id_to_prompt.items(): @@ -973,8 +1266,18 @@ def _run_generation( # Mark first input time for stage-0 metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() + # Check if stage 0 is the prefill-only stage in a PD pair + _seed_is_prefill = self._pd_separation_pair is not None and self._pd_separation_pair[0] == 0 + for req_id, prompt in request_id_to_prompt.items(): sp0 = sampling_params_list[0] # type: ignore[index] + + if _seed_is_prefill: + # PD disaggregation: prefill-only stage generates a single + # token so vLLM's KV connector saves the KV cache. + # Aligned with vLLM's disaggregated serving proxy pattern. + sp0 = self._prepare_prefill_sampling_params(req_id, sp0) + task = { "request_id": req_id, "engine_inputs": prompt, @@ -984,18 +1287,6 @@ def _run_generation( _req_start_ts[req_id] = time.time() logger.debug(f"[{self._name}] Enqueued request {req_id} to stage-0") - # Submit CFG companion requests to stage-0 - if cfg.is_active: - for companion_id, companion_prompt in expanded_companions: - task = { - "request_id": companion_id, - "engine_inputs": companion_prompt, - "sampling_params": cfg.stage0_sampling_params, - } - self.stage_list[0].submit(task) - _req_start_ts[companion_id] = time.time() - logger.debug(f"[{self._name}] Enqueued CFG companion {companion_id} to stage-0") - pbar = None if use_tqdm: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm @@ -1007,7 +1298,7 @@ def _run_generation( ) # For each stage, forward results to next stage; collect finals at the end # We pipeline by continually polling output queues in stage order - remaining_by_stage: list[int] = [len(request_prompts) + cfg.num_companions] + [0] * (num_stages - 1) + remaining_by_stage: list[int] = [len(request_prompts)] + [0] * (num_stages - 1) completed_requests = 0 total_requests = len(request_prompts) @@ -1024,17 +1315,11 @@ def _run_generation( made_progress = True req_id = result.get("request_id") 
if "error" in result: + if req_id is not None: + self._drop_pd_kv_params(req_id) logger.error( f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", ) - if cfg.is_companion(req_id) and stage_id == 0: - parent_id, parent_aborted = cfg.on_companion_error(req_id) - if parent_aborted: - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {parent_id} aborted due to " - f"companion failure ({completed_requests}/{total_requests})", - ) continue if result.get("type") == "stage_ready": @@ -1043,31 +1328,19 @@ def _run_generation( time.sleep(0.05) continue - # CFG: companion requests only run through Stage-0 - if cfg.is_companion(req_id) and stage_id == 0: - ready_parent = cfg.on_companion_completed(req_id) - if ready_parent is not None: - success = cfg.forward_parent_with_cfg( - ready_parent, - cfg.pop_pending_parent(ready_parent), - self.stage_list, - self.connectors, - sampling_params_list, - request_id_to_prompt, - final_stage_id_to_prompt, - metrics, - remaining_by_stage, - ) - if not success: - cfg.consume_parent_failure(ready_parent) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {ready_parent} dropped due to CFG forwarding failure " - f"({completed_requests}/{total_requests})", - ) - continue - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + + # PD: extract kv_transfer_params from prefill engine outputs + # (must happen after _load so shared memory is read only once). 
+ if ( + self._pd_separation_pair is not None + and req_id is not None + and stage_id == self._pd_separation_pair[0] + ): + kv_params = self._extract_kv_transfer_params(engine_outputs) + if kv_params is not None: + with self._pd_kv_params_lock: + self._pd_kv_params_by_req[req_id] = kv_params # Mark last output time for this stage whenever we receive outputs metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) try: @@ -1156,55 +1429,101 @@ def _run_generation( next_stage_id = stage_id + 1 if next_stage_id <= final_stage_id_to_prompt[req_id]: - # CFG: if this parent has companions, defer forwarding - if cfg.has_companions(req_id) and stage_id == 0: - if cfg.is_parent_failed(req_id): - cfg.consume_parent_failure(req_id) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {req_id} skipped CFG forwarding due to " - f"companion failure ({completed_requests}/{total_requests})", - ) - continue + next_stage: OmniStage = self.stage_list[next_stage_id] - if cfg.all_companions_done(req_id): - success = cfg.forward_parent_with_cfg( + # PD disaggregation: when routing from prefill to decode, + # re-submit the original prompt so the decode engine can + # load the prefilled KV cache via vLLM's native connector. + is_pd_routing = self._pd_separation_pair is not None and self._pd_separation_pair == ( + stage_id, + next_stage_id, + ) + + if is_pd_routing: + # Use the original prompt as decode engine input + original_prompt = request_id_to_prompt[req_id] + next_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt + + sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + sp_next = sp_next.clone() + if sp_next.extra_args is None: + sp_next.extra_args = {} + + # Build decode-side kv_transfer_params. The decode + # engine needs ``do_remote_prefill=True`` so the KV + # connector loads the remotely prefilled KV cache. 
+ decode_kv_params: dict[str, Any] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "transfer_id": f"xfer-{req_id}", + } + + # Merge any user-provided kv_transfer_params first + existing_kv_params = self._normalize_kv_transfer_params( + sp_next.extra_args.get("kv_transfer_params") + ) + if existing_kv_params: + decode_kv_params.update(existing_kv_params) + + # Add prefill engine connection info from config (if missing) + if self._pd_connector_info: + eid = self._pd_connector_info.get("prefill_engine_id") + if eid is not None and "remote_engine_id" not in decode_kv_params: + decode_kv_params["remote_engine_id"] = eid + baddr = self._pd_connector_info.get("prefill_bootstrap_addr") + if baddr is not None and "remote_bootstrap_addr" not in decode_kv_params: + decode_kv_params["remote_bootstrap_addr"] = baddr + + # If the prefill output carried connector metadata, + # merge it in (some connectors return additional info). + kv_params_from_output = self._pop_pd_kv_params(req_id, result.get("kv_transfer_params")) + if kv_params_from_output: + decode_kv_params.update(kv_params_from_output) + + # Ensure the decode role flags are correct + decode_kv_params["do_remote_prefill"] = True + decode_kv_params["do_remote_decode"] = False + if not decode_kv_params.get("transfer_id"): + decode_kv_params["transfer_id"] = f"xfer-{req_id}" + + sp_next.extra_args["kv_transfer_params"] = decode_kv_params + logger.info( + "[%s] PD routing: stage-%d→stage-%d, req %s, remote_request_id=%s, remote=%s:%s", + self._name, + stage_id, + next_stage_id, + req_id, + decode_kv_params.get("remote_request_id", "NOT SET"), + decode_kv_params.get("remote_host", "?"), + decode_kv_params.get("remote_port", "?"), + ) + if "remote_request_id" not in decode_kv_params: + logger.warning( + "[%s] PD routing: remote_request_id NOT SET " + "in decode_kv_params for req %s. 
Apply the " + "mooncake_connector.py patch to fix this.", + self._name, req_id, - {"engine_outputs": engine_outputs, "stage_id": stage_id}, - self.stage_list, - self.connectors, - sampling_params_list, - request_id_to_prompt, - final_stage_id_to_prompt, - metrics, - remaining_by_stage, ) - if not success: - cfg.consume_parent_failure(req_id) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {req_id} dropped due to CFG forwarding failure " - f"({completed_requests}/{total_requests})", + else: + try: + # Derive inputs for the next stage, record preprocess time + with metrics.stage_postprocess_timer(stage_id, req_id): + next_inputs = next_stage.process_engine_inputs( + self.stage_list, [request_id_to_prompt[req_id]] ) - else: - cfg.defer_parent(req_id, engine_outputs, stage_id) - continue - - next_stage: OmniStage = self.stage_list[next_stage_id] - try: - # Derive inputs for the next stage, record preprocess time - with metrics.stage_postprocess_timer(stage_id, req_id): - next_inputs = next_stage.process_engine_inputs( - self.stage_list, [request_id_to_prompt[req_id]] + except Exception as e: + logger.exception( + f"[{self._name}] Process engine inputs error for req {req_id}" + f" at stage {next_stage_id}: {e}", ) - except Exception as e: - completed_requests += 1 - logger.exception( - f"[{self._name}] Process engine inputs error for req {req_id}" - f" at stage {next_stage_id}: {e} ({completed_requests}/{total_requests})", - ) - continue - sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + continue + sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + + # If we are about to enter the prefill stage (when it is not stage-0), + # apply prefill-only sampling params. 
+ if self._pd_separation_pair is not None and next_stage_id == self._pd_separation_pair[0]: + sp_next = self._prepare_prefill_sampling_params(req_id, sp_next) # Check if we have a connector for this edge connector_key = (str(stage_id), str(next_stage_id)) @@ -1228,13 +1547,14 @@ def _run_generation( f"[{self._name}] Failed to send request {req_id} to stage-{next_stage_id} via connector. " "Configure a connector for this edge or inspect connector logs for details." ) - logger.debug( f"[{self._name}] Forwarded request {req_id} to stage-{next_stage_id}", ) remaining_by_stage[next_stage_id] += 1 else: completed_requests += 1 + if req_id is not None: + self._drop_pd_kv_params(req_id) if pbar: final_mod = self.output_modalities[final_stage_id_to_prompt[req_id]] pbar.unit = "img" if final_mod == "image" else "req" @@ -1242,12 +1562,6 @@ def _run_generation( logger.debug( f"[{self._name}] Request {req_id} fully completed ({completed_requests}/{total_requests})", ) - for timed_out_id in cfg.check_timeouts(): - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {timed_out_id} timed out; counting as failed " - f"({completed_requests}/{total_requests})", - ) if not made_progress: time.sleep(0.005) @@ -1256,6 +1570,11 @@ def _run_generation( if pbar: pbar.close() + # Defense-in-depth: drop any leftover PD KV params for this batch + # in case error/completion paths missed a cleanup. 
+ for rid in request_ids: + self._drop_pd_kv_params(rid) + # Summarize and print stats try: metrics.build_and_log_summary() diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index ec9248e304..f797c51937 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -215,9 +215,11 @@ def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[Re total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() + for output in step_outputs: if output.finished: outputs.append(output) + if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput @@ -239,4 +241,70 @@ def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[Re # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. # This is necessary because some requests may be finished earlier than # its previous requests. + + # PD disaggregation: flush any pending KV connector sends that were + # added by request_finished() after the last build_connector_meta() + # call. Without this flush, the prefill engine's worker never + # receives the block IDs needed for KV transfer to the decode engine. + self._flush_kv_connector_sends() + return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) + + def _flush_kv_connector_sends(self) -> None: + """Flush pending KV connector send metadata to workers. + + When _run_engine() finishes a batch, request_finished() may have + added entries to _reqs_need_send *after* the last + build_connector_meta() call within that step's schedule(). In + standard vLLM online serving this is not a problem because the + engine loop continues and the next schedule() picks them up. In + OmniLLM batch mode the loop exits immediately, so we must run one + more empty-request step to deliver the metadata. 
+ + NOTE: This method reaches into vLLM internals + (engine_core → scheduler → connector → connector_scheduler → + _reqs_need_send) and fabricates a SchedulerOutput to call + execute_model() directly. There is no public Engine API to + "flush pending KV sends" because vLLM assumes the engine loop + runs continuously. OmniLLM's batch mode breaks that assumption. + If upstream refactors rename or restructure these internals, this + method will need to be updated. Tested against vLLM >= 0.8.x. + """ + try: + engine_core = getattr(self.llm_engine, "engine_core", None) + if engine_core is None: + return + scheduler = getattr(engine_core, "scheduler", None) + if scheduler is None: + return + connector = getattr(scheduler, "connector", None) + if connector is None: + return + cs = getattr(connector, "connector_scheduler", None) + if cs is None: + return + + pending = getattr(cs, "_reqs_need_send", None) + if not pending: + return + + from vllm.v1.core.sched.output import SchedulerOutput + + # Create an empty scheduler output and attach connector metadata. + so = SchedulerOutput.make_empty() + so.kv_connector_metadata = connector.build_connector_meta(so) + + # Run an empty model step: the worker sees + # total_num_scheduled_tokens == 0 and takes the no_forward() + # path, which only processes connector metadata + # (record_send_reqs → sets ready event on SendBlockMeta). 
+ model_executor = getattr(engine_core, "model_executor", None) + if model_executor is None: + return + model_executor.execute_model(so) + logger.debug("[OmniLLM] KV connector sends flushed") + except Exception: + logger.warning( + "[OmniLLM] Failed to flush KV connector sends", + exc_info=True, + ) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 098cfa15d8..5fce946332 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -47,7 +47,6 @@ _resolve_model_tokenizer_paths, _to_dict, is_profiler_task, - load_func_from_config, maybe_dump_to_shm, set_stage_devices, ) @@ -277,6 +276,9 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): and self.stage_id is not None ): stage_config.engine_args.stage_id = self.stage_id + # PD disaggregation flags + self.is_prefill_only: bool = getattr(stage_config, "is_prefill_only", False) + self.is_decode_only: bool = getattr(stage_config, "is_decode_only", False) if hasattr(stage_config, "custom_process_input_func"): # Import the module specified in the config (already a full module path) module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) @@ -285,7 +287,6 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): else: self.custom_process_input_func = None - self.prompt_expand_func = load_func_from_config(getattr(stage_config, "prompt_expand_func", None)) self.final_output = getattr(stage_config, "final_output", False) self.final_output_type = getattr(stage_config, "final_output_type", None) self.tts_args = _to_dict(getattr(stage_config, "tts_args", {})) @@ -466,9 +467,10 @@ def init_stage_worker( "connectors_config": connectors_config or {}, "stage_type": self.stage_type, "engine_input_source": self.engine_input_source, - "cfg_kv_collect_func": getattr(self.stage_config, "cfg_kv_collect_func", None), "final_output": self.final_output, "final_output_type": self.final_output_type, + 
"is_prefill_only": self.is_prefill_only, + "is_decode_only": self.is_decode_only, } try: old_env = os.environ.get("VLLM_LOGGING_PREFIX") @@ -603,6 +605,19 @@ def _inject_global_id(target_ein): else: _inject_global_id(ein) + # PD safeguard: store kv_transfer_params as a plain-dict backup in the + # payload so it definitely survives pickle even if the msgspec.Struct + # extra_args field is silently dropped (observed with omit_defaults=True + # in SamplingParams). The backup fires only when _kv_transfer_params + # is not already in the payload, so overhead is minimal. + # TODO: open vLLM issue if confirmed reproducible + if "_kv_transfer_params" not in payload: + sp = payload.get("sampling_params") + if sp is not None and hasattr(sp, "extra_args") and sp.extra_args: + kv_tp = sp.extra_args.get("kv_transfer_params") + if kv_tp is not None: + payload["_kv_transfer_params"] = dict(kv_tp) + self._in_q.put(payload) def try_collect(self) -> dict[str, Any] | None: @@ -611,6 +626,8 @@ def try_collect(self) -> dict[str, Any] | None: Returns: Result dictionary if available, None otherwise. Result contains request_id, engine_outputs (or engine_outputs_shm), and metrics. + For prefill-only stages, also includes kv_transfer_params extracted + from vLLM's RequestOutput if available. """ assert self._out_q is not None try: @@ -619,11 +636,7 @@ def try_collect(self) -> dict[str, Any] | None: return None def process_engine_inputs( - self, - stage_list: list[Any], - prompt: OmniTokensPrompt | TextPrompt = None, - *, - source_outputs_override: Any = None, + self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None ) -> list[OmniTokensPrompt | TextPrompt]: """Process engine inputs for this stage from upstream stage outputs. 
@@ -634,8 +647,6 @@ def process_engine_inputs( Args: stage_list: List of all stages in the pipeline prompt: Optional original prompt (for multimodal data preservation) - source_outputs_override: Use these outputs instead of reading from - the source stage's ``engine_outputs`` (for deferred CFG requests). Returns: List of processed engine inputs ready for this stage @@ -648,11 +659,7 @@ def process_engine_inputs( if len(self.engine_input_source) == 0: raise ValueError("engine_input_source is empty") source_stage_id = self.engine_input_source[0] - source_outputs = ( - source_outputs_override - if source_outputs_override is not None - else stage_list[source_stage_id].engine_outputs - ) + source_outputs = stage_list[source_stage_id].engine_outputs if not isinstance(prompt, list): prompt = [prompt] multi_modal_data = { @@ -674,18 +681,6 @@ def process_engine_inputs( else: engine_input_source = self.engine_input_source - if source_outputs_override is not None and engine_input_source: - # Temporarily swap engine_outputs so custom_process_input_func - # (which reads stage_list directly) sees the correct data. 
- _source_id = engine_input_source[0] - _orig_outputs = stage_list[_source_id].engine_outputs - stage_list[_source_id].engine_outputs = source_outputs_override - try: - return self.custom_process_input_func( - stage_list, engine_input_source, prompt, self.requires_multimodal_data - ) - finally: - stage_list[_source_id].engine_outputs = _orig_outputs return self.custom_process_input_func( stage_list, engine_input_source, prompt, self.requires_multimodal_data ) @@ -711,6 +706,17 @@ def _stage_worker( from vllm_omni.plugins import load_omni_general_plugins load_omni_general_plugins() + + # -- PD disaggregation: monkey-patch MooncakeConnector before engine init -- + _is_prefill_only = stage_payload.get("is_prefill_only", False) + _is_decode_only = stage_payload.get("is_decode_only", False) + if _is_prefill_only or _is_decode_only: + _kv_cfg = stage_payload.get("engine_args", {}).get("kv_transfer_config", {}) + _engine_id = _kv_cfg.get("engine_id") if isinstance(_kv_cfg, dict) else None + from vllm_omni.distributed.kv_transfer.monkey_patch import apply_mooncake_connector_patch + + apply_mooncake_connector_patch(engine_id=_engine_id) + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / # GPUARModelRunner) are spawned with a fork-safe method. # Mooncake / gRPC / RDMA and CUDA/NCCL can deadlock under fork-with-threads. 
@@ -731,8 +737,6 @@ def _stage_worker( connectors_config = stage_payload.get("connectors_config", {}) stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") - cfg_kv_collect_func = load_func_from_config(stage_payload.get("cfg_kv_collect_func")) - if stage_type != "diffusion": _resolve_worker_cls(engine_args) @@ -792,7 +796,6 @@ def _stage_worker( model=model, stage_id=stage_id, engine_input_source=stage_payload.get("engine_input_source", []), - cfg_kv_collect_func=cfg_kv_collect_func, **engine_args, ) else: @@ -928,6 +931,23 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: # Ensure that the popped tasks are with identical sampling params. Take one of them. batch_engine_sampling_params: OmniSamplingParams = batch_tasks[0]["sampling_params"] + # PD safeguard: if the task carries _kv_transfer_params (backup key + # stored by the orchestrator), ensure it's present in the SP's + # extra_args. msgspec.Struct pickle with omit_defaults=True may + # silently drop the extra_args dict in some environments. + _kv_backup = batch_tasks[0].get("_kv_transfer_params") + if _kv_backup is not None: + sp = batch_engine_sampling_params + if isinstance(sp, SamplingParams): + if sp.extra_args is None: + sp.extra_args = {} + if "kv_transfer_params" not in sp.extra_args: + sp.extra_args["kv_transfer_params"] = _kv_backup + logger.warning( + "[Stage-%d][PD] Restored kv_transfer_params from backup (pickle dropped extra_args)", + stage_id, + ) + batch_request_ids: list[Any] = [] batch_engine_inputs: list[OmniPromptType] = [] _rx_bytes_by_rid: dict[Any, int] = {} @@ -1009,6 +1029,21 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 logger.debug(f"Generate done: batch={len(batch_tasks)}, req_ids={batch_request_ids}, gen_ms={_gen_ms:.1f}") + # PD check: MooncakeConnector only sends KV when + # finish_reason is 'length' (FINISHED_LENGTH_CAPPED). 
+ if _kv_backup is not None: + for ro in gen_outputs: + _ro_fr = ( + getattr(ro.outputs[0], "finish_reason", None) if hasattr(ro, "outputs") and ro.outputs else None + ) + if _ro_fr and str(_ro_fr) != "length": + logger.warning( + "[Stage-%d][PD] finish_reason=%s (not 'length') — KV transfer will be skipped for req %s", + stage_id, + _ro_fr, + ro.request_id, + ) + # Group outputs per request id with fallback req_to_outputs: dict[Any, list[Any]] = {rid: [] for rid in batch_request_ids} unmapped: list[Any] = [] @@ -1122,6 +1157,17 @@ async def _stage_worker_async( from vllm_omni.plugins import load_omni_general_plugins load_omni_general_plugins() + + # -- PD disaggregation: monkey-patch MooncakeConnector before engine init -- + _is_prefill_only = stage_payload.get("is_prefill_only", False) + _is_decode_only = stage_payload.get("is_decode_only", False) + if _is_prefill_only or _is_decode_only: + _kv_cfg = stage_payload.get("engine_args", {}).get("kv_transfer_config", {}) + _engine_id = _kv_cfg.get("engine_id") if isinstance(_kv_cfg, dict) else None + from vllm_omni.distributed.kv_transfer.monkey_patch import apply_mooncake_connector_patch + + apply_mooncake_connector_patch(engine_id=_engine_id) + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / # GPUARModelRunner) are spawned with a fork-safe method. 
if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": @@ -1141,7 +1187,6 @@ async def _stage_worker_async( final_output = stage_payload.get("final_output", False) final_output_type = stage_payload.get("final_output_type", None) - cfg_kv_collect_func = load_func_from_config(stage_payload.get("cfg_kv_collect_func")) # Handle non-standard model directory structures (e.g., tokenizer in root, model in subdir) model = _resolve_model_tokenizer_paths(model, engine_args) @@ -1220,13 +1265,10 @@ async def _stage_worker_async( od_config["omni_kv_config"]["engine_input_source"] = stage_payload.get("engine_input_source", []) logger.debug(f"[Stage-%s] Initializing diffusion engine with config: {od_config}", stage_id) - _diffusion_kwargs = {k: v for k, v in engine_args.items() if k not in {"od_config", "model"}} - if cfg_kv_collect_func is not None: - _diffusion_kwargs["cfg_kv_collect_func"] = cfg_kv_collect_func stage_engine = AsyncOmniDiffusion( model=model, od_config=od_config, - **_diffusion_kwargs, + **{k: v for k, v in engine_args.items() if k not in {"od_config", "model"}}, ) vllm_config = None # Diffusion doesn't use vllm_config else: diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index f798a21d2e..edcaaf23b6 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -779,10 +779,62 @@ def _thinker_to_talker_prefill( Returns: (input_ids, input_embeds) for talker """ + # Pad thinker_embed / thinker_hidden to match thinker_result_ids length + # so that all downstream slices use consistent indices. The mismatch + # can happen when the thinker sequence includes generated tokens that + # don't have corresponding embeddings (e.g. in PD disaggregation). 
+ # + # Safety note: Zero-padded positions are safe because the talker's + # ChatML-segment loop (below) only slices embeddings within + # im_start_index boundaries. The padded tail falls outside the last + # assistant segment and is never attended to. Additionally, the + # multimodal_mask only selects audio/image/video token positions, + # which always lie within the prompt (prefill) portion where real + # embeddings exist. + target_len = thinker_result_ids.shape[-1] + _PD_PAD_THRESHOLD = 512 + if thinker_embed.shape[0] < target_len: + pad_len = target_len - thinker_embed.shape[0] + if pad_len > _PD_PAD_THRESHOLD: + logger.warning( + "[PD] Unexpectedly large embed padding: %d tokens " + "(threshold=%d). This may indicate a bug in PD " + "disaggregation.", + pad_len, + _PD_PAD_THRESHOLD, + ) + thinker_embed = torch.cat( + ( + thinker_embed, + torch.zeros( + pad_len, thinker_embed.shape[1], device=thinker_embed.device, dtype=thinker_embed.dtype + ), + ), + dim=0, + ) + if thinker_hidden.shape[0] < target_len: + pad_len = target_len - thinker_hidden.shape[0] + if pad_len > _PD_PAD_THRESHOLD: + logger.warning( + "[PD] Unexpectedly large hidden padding: %d tokens " + "(threshold=%d). This may indicate a bug in PD " + "disaggregation.", + pad_len, + _PD_PAD_THRESHOLD, + ) + thinker_hidden = torch.cat( + ( + thinker_hidden, + torch.zeros( + pad_len, thinker_hidden.shape[1], device=thinker_hidden.device, dtype=thinker_hidden.dtype + ), + ), + dim=0, + ) im_start_indexes = torch.cat( ( torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(), - torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype), + torch.tensor([target_len], device=input_ids.device, dtype=input_ids.dtype), ), dim=-1, ) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here. 
@@ -903,8 +955,25 @@ def talker_preprocess_decode( return last_talker_hidden, text_step, update_dict def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed): + # Clamp segment_end_index to the shortest tensor so mask and embed + # slices always have the same length (guards against length mismatches + # between thinker_result_ids, thinker_embed, and thinker_hidden). + segment_end_index = min( + segment_end_index, + multimodal_mask.shape[0], + thinker_hidden.shape[0], + thinker_embed.shape[0], + ) + seg_len = segment_end_index - im_start_index + if seg_len <= 0: + return torch.empty( + (0, self.config.talker_config.text_config.hidden_size), + device=thinker_hidden.device, + dtype=torch.bfloat16, + ) + user_talker_part = torch.empty( - (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size), + (seg_len, self.config.talker_config.text_config.hidden_size), device=thinker_hidden.device, dtype=torch.bfloat16, ) diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml new file mode 100644 index 0000000000..fbf1479dc9 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml @@ -0,0 +1,199 @@ +# Stage config for Qwen3-Omni-MoE with Prefill-Decode disaggregation +# +# Splits the thinker stage into separate prefill and decode instances. +# The prefill stage processes prompts and transfers KV cache to the +# decode stage via vLLM's native KV connector (e.g., MooncakeConnector). +# +# Stage 0: Thinker Prefill (prompt processing, KV producer) +# Stage 1: Thinker Decode (token generation, KV consumer) +# Stage 2: Talker (text embeddings -> RVQ codec codes) +# Stage 3: Code2Wav (RVQ codes -> audio waveform) +# +# Requirements: +# - A supported KV connector (MooncakeConnector, NixlConnector, etc.) 
+# - Prefill and decode stages must be able to communicate via the connector +# - Mooncake transfer engine must be installed if using MooncakeConnector +# +# The orchestrator overrides max_tokens=1 for the prefill stage so it +# performs only prompt processing + KV save, then the decode stage loads +# the KV cache and generates the full response. +# +# engine_id values must be set explicitly so the orchestrator can tell the +# decode engine where to pull KV from (the prefill engine's identity). +# +# IMPORTANT: MooncakeConnector does not support heterogeneous TP sizes. +# Both prefill and decode stages MUST use the same tensor_parallel_size. +# If the thinker model requires TP=2, set both stages to TP=2 and +# allocate 2 GPUs for each (e.g. devices "0,1" and "2,3", 5 GPUs total). +# +# Example layout on 3x H100-80G GPUs (TP=1 for both): +# GPU 0: Thinker Prefill +# GPU 1: Thinker Decode +# GPU 2: Talker + Code2Wav + +async_chunk: false +stage_args: + - stage_id: 0 + stage_type: llm + is_prefill_only: true + runtime: + devices: "0" + max_batch_size: 16 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_producer" + kv_rank: 0 + kv_parallel_size: 2 + engine_id: "omni-thinker-prefill" + kv_connector_extra_config: + mooncake_bootstrap_port: 25201 + final_output: false + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm + is_decode_only: true + runtime: + 
devices: "1" + max_batch_size: 64 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_consumer" + kv_rank: 1 + kv_parallel_size: 2 + engine_id: "omni-thinker-decode" + kv_connector_extra_config: + mooncake_bootstrap_port: 25202 + engine_input_source: [0] + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 2 + stage_type: llm + runtime: + devices: "2" + max_batch_size: 64 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 3 + stage_type: llm + runtime: + devices: "2" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: 
vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [2] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + + edges: + - from: 0 + to: 1 + window_size: -1 + - from: 1 + to: 2 + window_size: -1 + - from: 2 + to: 3 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 3a42159a8f..61d3e6a2b0 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -3,6 +3,7 @@ # Copyright 2025 The Qwen team. """Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition.""" +import logging from typing import Any import torch @@ -12,6 +13,16 @@ from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt +logger = logging.getLogger(__name__) + +# Pooling output layer keys used by the thinker → talker transition. +# "0" is always the word embedding layer. "24" corresponds to the talker's +# ``accept_hidden_layer`` config (``TalkerConfig.accept_hidden_layer``). 
+# If the model config changes this value, update _HIDDEN_LAYER_KEY accordingly +# or derive it dynamically from the stage config at initialisation time. +_EMBED_LAYER_KEY = "0" +_HIDDEN_LAYER_KEY = "24" + def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: im_start_token_id = 151644 @@ -105,8 +116,8 @@ def thinker2talker_async_chunk( all_token_ids = _ensure_list(all_token_ids) prompt_token_ids = _ensure_list(prompt_token_ids) talker_additional_info = { - "thinker_embeddings": pooling_output.get("0").detach().cpu(), - "thinker_hidden_states": pooling_output.get("24").detach().cpu(), + "thinker_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(), + "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(), "thinker_sequences": all_token_ids, "thinker_input_ids": prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection @@ -134,16 +145,81 @@ def thinker2talker_async_chunk( output_token_ids = _ensure_list(output_token_ids) talker_additional_info = { - "thinker_embeddings": pooling_output.get("0").detach().cpu(), + "thinker_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(), + "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(), + "thinker_sequences": output_token_ids, "finished": torch.tensor(is_finished, dtype=torch.bool), } if not output_token_ids: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. 
- talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu() + talker_additional_info["thinker_hidden_states"] = pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu() return talker_additional_info +def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None: + """Return the preceding prefill stage if PD disaggregation is active.""" + if source_stage_id <= 0: + return None + source_stage = stage_list[source_stage_id] + if not getattr(source_stage, "is_decode_only", False): + return None + prev_stage = stage_list[source_stage_id - 1] + if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None: + return prev_stage + return None + + +def _merge_pd_embeddings( + decode_emb: torch.Tensor, + decode_hid: torch.Tensor, + prefill_mm: dict[str, Any], + device: torch.device, + expected_total: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Merge prefill prompt embeddings with decode generated embeddings. + + In PD disaggregation the decode engine only produces embeddings for the + tokens it actually computed. The prefill engine has embeddings for the + full prompt. We concatenate them, dynamically computing any overlap:: + + overlap = prefill_len + decode_len - expected_total + merged = prefill + decode[overlap:] + + When ``expected_total`` (= len(prompt_token_ids) + len(output.token_ids)) + is provided we use it to decide how many leading decode embeddings to + skip (they duplicate trailing prefill positions). If not provided we + fall back to no-skip concatenation. 
+ """ + try: + p_emb = prefill_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float) + p_hid = prefill_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float) + except (KeyError, AttributeError, TypeError): + return decode_emb, decode_hid + + if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0: + return decode_emb, decode_hid + + raw_total = p_emb.shape[0] + decode_emb.shape[0] + if expected_total is not None and raw_total > expected_total: + overlap = raw_total - expected_total + else: + overlap = 0 + + merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0) + merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0) + + logger.debug( + "[PD] Merged prefill(%d) + decode(%d) overlap=%d → %d embeddings (expected=%s)", + p_emb.shape[0], + decode_emb.shape[0], + overlap, + merged_emb.shape[0], + expected_total, + ) + return merged_emb, merged_hid + + def thinker2talker( stage_list: list[Any], engine_input_source: list[int], @@ -158,6 +234,12 @@ def thinker2talker( 2. Split hidden states into: prompt embeddings + generated embeddings 3. Package for talker with additional information + In PD disaggregation the decode engine's multimodal_output only covers + the tokens it computed (not the full prompt). When a preceding prefill + stage is detected we merge the prefill's prompt embeddings with the + decode's generated embeddings so the talker receives the complete + sequence. + Args: stage_list: List of stage objects engine_input_source: Source stage IDs (typically [0] for thinker) @@ -172,21 +254,67 @@ def thinker2talker( device = torch.device(current_platform.device_type) + # PD disaggregation: look for a preceding prefill stage whose + # embeddings we need to merge with the decode output. 
+ source_stage_id = engine_input_source[0] + prefill_stage = _get_prefill_stage(stage_list, source_stage_id) + # Process each thinker output - for thinker_output in thinker_outputs: + for i, thinker_output in enumerate(thinker_outputs): output = thinker_output.outputs[0] + decode_emb = output.multimodal_output[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float) + decode_hid = output.multimodal_output[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float) + + # Expected total = prompt tokens + generated tokens (the full sequence). + expected_total = len(thinker_output.prompt_token_ids) + len(output.token_ids) + + logger.debug( + "[PD] thinker2talker: prompt_len=%d, output_len=%d, expected_total=%d, decode_emb=%d, decode_hid=%d", + len(thinker_output.prompt_token_ids), + len(output.token_ids), + expected_total, + decode_emb.shape[0], + decode_hid.shape[0], + ) + + # Merge prefill prompt embeddings when running in PD mode. + if prefill_stage is not None: + try: + prefill_eos = prefill_stage.engine_outputs + prefill_eo = prefill_eos[min(i, len(prefill_eos) - 1)] + prefill_mm = prefill_eo.outputs[0].multimodal_output + decode_emb, decode_hid = _merge_pd_embeddings( + decode_emb, + decode_hid, + prefill_mm, + device, + expected_total=expected_total, + ) + except Exception as exc: + logger.warning("[PD] Could not merge prefill embeddings: %s", exc) + + # Helper: get TTS embed from decode, fall back to prefill if missing. 
+ def _tts(key: str) -> torch.Tensor: + val = output.multimodal_output.get(key) + if val is None and prefill_stage is not None: + try: + val = prefill_stage.engine_outputs[0].outputs[0].multimodal_output.get(key) + except Exception: + pass + return val.detach().to(device=device, dtype=torch.float) if val is not None else None + info = { - "thinker_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), - "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float), + "thinker_embeddings": decode_emb, + "thinker_hidden_states": decode_hid, "thinker_sequences": ( thinker_output.prompt_token_ids + output.token_ids ), # the thinker_sequences is the whole ids "thinker_input_ids": thinker_output.prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection - "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float), - "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), - "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), + "tts_bos_embed": _tts("tts_bos_embed"), + "tts_eos_embed": _tts("tts_eos_embed"), + "tts_pad_embed": _tts("tts_pad_embed"), } prompt_len = _compute_talker_prompt_ids_length(info, device=device)