diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 509886a02b..f58f9a818c 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -21,6 +21,7 @@ from vllm.multimodal.image import convert_image_mode from vllm.utils.argparse_utils import FlexibleArgumentParser +# Import Omni engine from vllm_omni.entrypoints.omni import Omni SEED = 42 @@ -294,11 +295,30 @@ def main(args): else: query_result = query_func() + # Build kwargs with CLI overrides. + # Global params (e.g. --gpu-memory-utilization) apply to all stages; + # per-stage overrides (--stage-N-*) take precedence when specified. + omni_kwargs = { + "stage_configs_path": args.stage_configs_path, + "log_stats": args.log_stats, + "stage_init_timeout": args.stage_init_timeout, + } + + # Add CLI overrides if specified + if args.gpu_memory_utilization is not None: + omni_kwargs["gpu_memory_utilization"] = args.gpu_memory_utilization + if args.tensor_parallel_size is not None: + omni_kwargs["tensor_parallel_size"] = args.tensor_parallel_size + if args.devices is not None: + omni_kwargs["devices"] = args.devices + if args.enforce_eager: + omni_kwargs["enforce_eager"] = args.enforce_eager + if args.trust_remote_code: + omni_kwargs["trust_remote_code"] = args.trust_remote_code + omni_llm = Omni( model=model_name, - stage_configs_path=args.stage_configs_path, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, + **omni_kwargs, ) thinker_sampling_params = SamplingParams( @@ -458,6 +478,12 @@ def parse_args(): default="output_audio", help="[Deprecated] Output wav directory (use --output-dir).", ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory for generated files (text and audio).", + ) parser.add_argument( "--num-prompts", type=int, @@ -474,7 +500,38 @@ def parse_args(): "--stage-configs-path", type=str, default=None, - help="Path 
to a stage configs file.", + help="Path to a stage configs file. If not specified, auto-detected from model.", + ) + # CLI override arguments + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=None, + help="GPU memory utilization for all stages (CLI override). Example: 0.9", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=None, + help="Tensor parallel size for all stages (CLI override). Example: 2", + ) + parser.add_argument( + "--devices", + type=str, + default=None, + help="Device assignment for stages (CLI override). Example: '0,1'", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + default=False, + help="Enforce eager mode for all stages (CLI override).", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + default=False, + help="Trust remote code for model loading (CLI override).", ) parser.add_argument( "--video-path", diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py new file mode 100644 index 0000000000..26edfb43c7 --- /dev/null +++ b/tests/test_config_factory.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for StageConfigFactory and related classes. 
+""" + +import pytest + +from vllm_omni.config.stage_config import ( + ModelPipeline, + StageConfig, + StageConfigFactory, + StageType, +) + + +class TestStageType: + """Tests for StageType enum.""" + + def test_stage_type_values(self): + """Test StageType enum values.""" + assert StageType.LLM.value == "llm" + assert StageType.DIFFUSION.value == "diffusion" + + def test_stage_type_from_string(self): + """Test creating StageType from string.""" + assert StageType("llm") == StageType.LLM + assert StageType("diffusion") == StageType.DIFFUSION + + +class TestStageConfig: + """Tests for StageConfig dataclass.""" + + def test_minimal_config(self): + """Test creating StageConfig with minimal required fields.""" + config = StageConfig(stage_id=0, model_stage="thinker") + assert config.stage_id == 0 + assert config.model_stage == "thinker" + assert config.stage_type == StageType.LLM + assert config.input_sources == [] + assert config.final_output is False + assert config.worker_type is None + + def test_full_config(self): + """Test creating StageConfig with all fields.""" + config = StageConfig( + stage_id=1, + model_stage="talker", + stage_type=StageType.LLM, + input_sources=[0], + custom_process_input_func="module.path.func", + final_output=True, + final_output_type="audio", + worker_type="ar", + scheduler_cls="path.to.Scheduler", + hf_config_name="talker_config", + is_comprehension=False, + ) + assert config.stage_id == 1 + assert config.model_stage == "talker" + assert config.input_sources == [0] + assert config.final_output_type == "audio" + assert config.worker_type == "ar" + + def test_to_omegaconf_basic(self): + """Test converting StageConfig to OmegaConf format.""" + config = StageConfig( + stage_id=0, + model_stage="thinker", + stage_type=StageType.LLM, + worker_type="ar", + final_output=True, + final_output_type="text", + ) + omega_config = config.to_omegaconf() + + assert omega_config.stage_id == 0 + assert omega_config.stage_type == "llm" + assert 
omega_config.engine_args.model_stage == "thinker" + assert omega_config.engine_args.worker_type == "ar" + assert omega_config.final_output is True + assert omega_config.final_output_type == "text" + # Legacy field name for backward compatibility + assert omega_config.engine_input_source == [] + + def test_to_omegaconf_with_runtime_overrides(self): + """Test that runtime overrides are applied to OmegaConf output.""" + config = StageConfig( + stage_id=0, + model_stage="thinker", + runtime_overrides={ + "gpu_memory_utilization": 0.9, + "tensor_parallel_size": 2, + "devices": "0,1", + "max_batch_size": 64, + }, + ) + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.gpu_memory_utilization == 0.9 + assert omega_config.engine_args.tensor_parallel_size == 2 + assert omega_config.runtime.devices == "0,1" + assert omega_config.runtime.max_batch_size == 64 + + +class TestModelPipeline: + """Tests for ModelPipeline class.""" + + def test_valid_linear_dag(self): + """Test validation of a valid linear DAG.""" + stages = [ + StageConfig(stage_id=0, model_stage="thinker", input_sources=[]), + StageConfig(stage_id=1, model_stage="talker", input_sources=[0]), + StageConfig(stage_id=2, model_stage="code2wav", input_sources=[1]), + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert errors == [], f"Unexpected errors: {errors}" + + def test_valid_branching_dag(self): + """Test validation of a valid branching DAG.""" + stages = [ + StageConfig(stage_id=0, model_stage="input", input_sources=[]), + StageConfig(stage_id=1, model_stage="branch_a", input_sources=[0]), + StageConfig(stage_id=2, model_stage="branch_b", input_sources=[0]), + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert errors == [], f"Unexpected errors: {errors}" + + def test_missing_entry_point(self): + """Test that missing entry point is detected.""" + stages = [ + 
StageConfig(stage_id=0, model_stage="stage_a", input_sources=[1]), + StageConfig(stage_id=1, model_stage="stage_b", input_sources=[0]), + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert any("entry point" in e.lower() for e in errors) + + def test_missing_dependency(self): + """Test that missing stage reference is detected.""" + stages = [ + StageConfig(stage_id=0, model_stage="input", input_sources=[]), + StageConfig(stage_id=1, model_stage="output", input_sources=[99]), # Invalid + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert any("non-existent" in e.lower() for e in errors) + + def test_duplicate_stage_ids(self): + """Test that duplicate stage IDs are detected.""" + stages = [ + StageConfig(stage_id=0, model_stage="stage_a", input_sources=[]), + StageConfig(stage_id=0, model_stage="stage_b", input_sources=[]), # Duplicate + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert any("duplicate" in e.lower() for e in errors) + + def test_self_reference(self): + """Test that self-references are detected.""" + stages = [ + StageConfig(stage_id=0, model_stage="entry", input_sources=[]), + StageConfig(stage_id=1, model_stage="self_ref", input_sources=[1]), # Self + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + errors = pipeline.validate_pipeline() + assert any("itself" in e.lower() for e in errors) + + def test_get_stage_by_id(self): + """Test getting stage by ID.""" + stages = [ + StageConfig(stage_id=0, model_stage="thinker", input_sources=[]), + StageConfig(stage_id=1, model_stage="talker", input_sources=[0]), + ] + pipeline = ModelPipeline(model_type="test", stages=stages) + + stage = pipeline.get_stage(1) + assert stage is not None + assert stage.model_stage == "talker" + + missing = pipeline.get_stage(99) + assert missing is None + + def 
test_empty_pipeline(self): + """Test validation of empty pipeline.""" + pipeline = ModelPipeline(model_type="test", stages=[]) + errors = pipeline.validate_pipeline() + assert any("no stages" in e.lower() for e in errors) + + +class TestStageConfigFactory: + """Tests for StageConfigFactory class.""" + + def test_default_diffusion_no_yaml(self): + """Test single-stage diffusion works without YAML config (@ZJY0516).""" + kwargs = { + "cache_backend": "none", + "cache_config": None, + "dtype": "bfloat16", + } + configs = StageConfigFactory.create_default_diffusion(kwargs) + + assert len(configs) == 1 + cfg = configs[0] + assert cfg["stage_id"] == 0 + assert cfg["stage_type"] == "diffusion" + assert cfg["final_output"] is True + assert cfg["final_output_type"] == "image" + + def test_default_diffusion_with_parallel_config(self): + """Test diffusion config calculates devices from parallel_config.""" + + class MockParallelConfig: + world_size = 4 + + kwargs = { + "parallel_config": MockParallelConfig(), + "cache_backend": "tea_cache", + } + configs = StageConfigFactory.create_default_diffusion(kwargs) + + assert configs[0]["runtime"]["devices"] == "0,1,2,3" + + def test_per_stage_override_precedence(self): + """Test that --stage-0-gpu-memory-utilization overrides global.""" + stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[]) + cli_overrides = { + "gpu_memory_utilization": 0.5, # Global + "stage_0_gpu_memory_utilization": 0.9, # Per-stage override + } + + overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides) + + # Per-stage should override global + assert overrides["gpu_memory_utilization"] == 0.9 + + def test_cli_override_forwards_engine_registered_args(self): + """Test that any engine-registered CLI arg is forwarded (@wuhang2014).""" + stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[]) + cli_overrides = { + "gpu_memory_utilization": 0.9, # Well-known param + "custom_engine_flag": True, # Not in 
_INTERNAL_KEYS, so forwarded + } + + overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides) + + assert overrides["gpu_memory_utilization"] == 0.9 + assert overrides["custom_engine_flag"] is True + + def test_cli_override_excludes_internal_keys(self): + """Test that internal/orchestrator keys are not forwarded.""" + stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[]) + cli_overrides = { + "gpu_memory_utilization": 0.9, + "model": "some_model", # Internal + "stage_configs_path": "/path", # Internal + "batch_timeout": 10, # Internal + } + + overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides) + + assert overrides["gpu_memory_utilization"] == 0.9 + assert "model" not in overrides + assert "stage_configs_path" not in overrides + assert "batch_timeout" not in overrides + + def test_per_stage_override_excludes_internal_keys(self): + """Test that per-stage overrides also skip internal keys.""" + stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[]) + cli_overrides = { + "stage_0_gpu_memory_utilization": 0.9, + "stage_0_model": "override_model", # Internal, should be skipped + "stage_0_batch_timeout": 5, # Internal, should be skipped + } + + overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides) + + assert overrides["gpu_memory_utilization"] == 0.9 + assert "model" not in overrides + assert "batch_timeout" not in overrides + + def test_all_pipeline_files_exist(self): + """Test that every entry in PIPELINE_MODELS has an actual YAML file.""" + from vllm_omni.model_pipelines import get_pipeline_path + + for model_type, dirname in StageConfigFactory.PIPELINE_MODELS.items(): + path = get_pipeline_path(dirname, "pipeline.yaml") + assert path.exists(), f"Missing pipeline file for {model_type}: {path}" + + @pytest.mark.parametrize("model_type", list(StageConfigFactory.PIPELINE_MODELS.keys())) + def test_parse_real_pipeline_files(self, model_type): + """Test that each shipped pipeline 
YAML parses and validates correctly.""" + from vllm_omni.model_pipelines import get_pipeline_path + + dirname = StageConfigFactory.PIPELINE_MODELS[model_type] + path = get_pipeline_path(dirname, "pipeline.yaml") + pipeline = StageConfigFactory._parse_pipeline_yaml(path, model_type) + + # Basic structure + assert pipeline.model_type == model_type + assert len(pipeline.stages) >= 1 + + # Must pass validation + errors = pipeline.validate_pipeline() + assert errors == [], f"{model_type}: {errors}" + + # Every stage must have required fields + for stage in pipeline.stages: + assert isinstance(stage.stage_id, int) + assert isinstance(stage.model_stage, str) + assert isinstance(stage.stage_type, StageType) + + +class TestPipelineYamlParsing: + """Tests for pipeline YAML file parsing (@ZJY0516).""" + + def test_parse_qwen3_omni_moe_yaml(self, tmp_path): + """Test parsing the qwen3_omni_moe pipeline YAML.""" + yaml_content = """\ +model_type: qwen3_omni_moe + +stages: + - stage_id: 0 + model_stage: thinker + stage_type: llm + input_sources: [] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + hf_config_name: thinker_config + final_output: true + final_output_type: text + is_comprehension: true + + - stage_id: 1 + model_stage: talker + stage_type: llm + input_sources: [0] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + hf_config_name: talker_config + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + + - stage_id: 2 + model_stage: code2wav + stage_type: llm + input_sources: [1] + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + hf_config_name: thinker_config + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio +""" + yaml_file = tmp_path / "qwen3_omni_moe.yaml" + 
yaml_file.write_text(yaml_content) + + pipeline = StageConfigFactory._parse_pipeline_yaml(yaml_file, "qwen3_omni_moe") + + assert pipeline.model_type == "qwen3_omni_moe" + assert len(pipeline.stages) == 3 + + # Stage 0: thinker + s0 = pipeline.stages[0] + assert s0.stage_id == 0 + assert s0.model_stage == "thinker" + assert s0.stage_type == StageType.LLM + assert s0.input_sources == [] + assert s0.worker_type == "ar" + assert s0.final_output is True + assert s0.final_output_type == "text" + assert s0.is_comprehension is True + + # Stage 1: talker + s1 = pipeline.stages[1] + assert s1.stage_id == 1 + assert s1.input_sources == [0] + assert s1.custom_process_input_func == ( + "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker" + ) + assert s1.final_output is False + + # Stage 2: code2wav + s2 = pipeline.stages[2] + assert s2.stage_id == 2 + assert s2.input_sources == [1] + assert s2.worker_type == "generation" + assert s2.final_output_type == "audio" + + def test_parse_yaml_with_legacy_engine_input_source(self, tmp_path): + """Test backward compatibility with engine_input_source field.""" + yaml_content = """\ +model_type: legacy_model + +stages: + - stage_id: 0 + model_stage: entry + stage_type: llm + - stage_id: 1 + model_stage: downstream + stage_type: llm + engine_input_source: [0] +""" + yaml_file = tmp_path / "legacy.yaml" + yaml_file.write_text(yaml_content) + + pipeline = StageConfigFactory._parse_pipeline_yaml(yaml_file, "legacy_model") + assert pipeline.stages[1].input_sources == [0] + + def test_parse_yaml_with_connectors_and_edges(self, tmp_path): + """Test parsing pipeline with optional connectors and edges.""" + yaml_content = """\ +model_type: test_model + +stages: + - stage_id: 0 + model_stage: entry + stage_type: llm + input_sources: [] + +connectors: + type: ray + +edges: + - from: 0 + to: 1 +""" + yaml_file = tmp_path / "with_connectors.yaml" + yaml_file.write_text(yaml_content) + + pipeline = 
StageConfigFactory._parse_pipeline_yaml(yaml_file, "test_model") + assert pipeline.connectors == {"type": "ray"} + assert pipeline.edges == [{"from": 0, "to": 1}] + + def test_parsed_pipeline_passes_validation(self, tmp_path): + """Test that a well-formed YAML produces a valid pipeline.""" + yaml_content = """\ +model_type: valid_model + +stages: + - stage_id: 0 + model_stage: entry + stage_type: llm + input_sources: [] + final_output: true + final_output_type: text + - stage_id: 1 + model_stage: next + stage_type: llm + input_sources: [0] +""" + yaml_file = tmp_path / "valid.yaml" + yaml_file.write_text(yaml_content) + + pipeline = StageConfigFactory._parse_pipeline_yaml(yaml_file, "valid_model") + errors = pipeline.validate_pipeline() + assert errors == [], f"Unexpected validation errors: {errors}" + + def test_parse_diffusion_stage_type(self, tmp_path): + """Test parsing a diffusion stage type from YAML.""" + yaml_content = """\ +model_type: diff_model + +stages: + - stage_id: 0 + model_stage: dit + stage_type: diffusion + input_sources: [] + final_output: true + final_output_type: image +""" + yaml_file = tmp_path / "diffusion.yaml" + yaml_file.write_text(yaml_content) + + pipeline = StageConfigFactory._parse_pipeline_yaml(yaml_file, "diff_model") + assert pipeline.stages[0].stage_type == StageType.DIFFUSION diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index e2db6f4273..2aa236e69f 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -4,8 +4,28 @@ from vllm_omni.config.lora import LoRAConfig from vllm_omni.config.model import OmniModelConfig +from vllm_omni.config.stage_config import ( + ModelPipeline, + StageConfig, + StageConfigFactory, + StageType, +) +from vllm_omni.config.yaml_util import ( + create_config, + load_yaml_config, + merge_configs, + to_dict, +) __all__ = [ "OmniModelConfig", "LoRAConfig", + "StageConfig", + "StageConfigFactory", + "ModelPipeline", + "StageType", + "create_config", + 
"load_yaml_config", + "merge_configs", + "to_dict", ] diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py new file mode 100644 index 0000000000..473e776d6a --- /dev/null +++ b/vllm_omni/config/stage_config.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Stage Configuration System for vLLM-Omni. + +Pipeline structure (stages, types, data-flow) is defined in per-model YAML +files and is set by model developers at integration time. +Runtime parameters (gpu_memory_utilization, tp_size, etc.) come from CLI. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any + +from vllm.logger import init_logger + +from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict +from vllm_omni.model_pipelines import get_pipeline_path + +logger = init_logger(__name__) + + +class StageType(str, Enum): + """Type of processing stage in the Omni pipeline.""" + + LLM = "llm" + DIFFUSION = "diffusion" + + +@dataclass +class StageConfig: + """Per-stage configuration — pipeline-structure fields only. + + Engine params (gpu_memory_utilization, tp_size, etc.) come from CLI, + NOT from this class. 
+ """ + + # Identity + stage_id: int + model_stage: str + + # Stage type + stage_type: StageType = StageType.LLM + + input_sources: list[int] = field(default_factory=list) + custom_process_input_func: str | None = None + final_output: bool = False + final_output_type: str | None = None # "text", "audio", "image" + worker_type: str | None = None # "ar" or "generation" + scheduler_cls: str | None = None + hf_config_name: str | None = None + is_comprehension: bool = False + + # Runtime overrides (populated from CLI, not from pipeline YAML) + runtime_overrides: dict[str, Any] = field(default_factory=dict) + + def to_omegaconf(self) -> Any: + """Convert to OmegaConf for backward compatibility with OmniStage. + + Returns: + OmegaConf DictConfig with stage configuration in legacy format. + """ + # Build engine_args dict with required fields + engine_args: dict[str, Any] = { + "model_stage": self.model_stage, + } + + if self.worker_type: + engine_args["worker_type"] = self.worker_type + if self.scheduler_cls: + engine_args["scheduler_cls"] = self.scheduler_cls + if self.hf_config_name: + engine_args["hf_config_name"] = self.hf_config_name + + # Apply runtime overrides (CLI args) + for key, value in self.runtime_overrides.items(): + if key not in ("devices", "max_batch_size"): + engine_args[key] = value + + # Build runtime config + runtime: dict[str, Any] = { + "process": True, + "max_batch_size": self.runtime_overrides.get("max_batch_size", 1), + } + if "devices" in self.runtime_overrides: + runtime["devices"] = self.runtime_overrides["devices"] + + # Build full config dict + config_dict: dict[str, Any] = { + "stage_id": self.stage_id, + "stage_type": StageType(self.stage_type).value, + "engine_args": create_config(engine_args), + "runtime": create_config(runtime), + "engine_input_source": self.input_sources, # Legacy field name + "final_output": self.final_output, + "final_output_type": self.final_output_type, + "is_comprehension": self.is_comprehension, + } + + if 
self.custom_process_input_func: + config_dict["custom_process_input_func"] = self.custom_process_input_func + + return create_config(config_dict) + + +@dataclass +class ModelPipeline: + """Complete pipeline definition for a multi-stage model. + + Defined by model developers, bundled with the model, not user-editable. + """ + + model_type: str + stages: list[StageConfig] + + # Optional distributed configuration + connectors: dict[str, Any] | None = None + edges: list[dict[str, Any]] | None = None + + def get_stage(self, stage_id: int) -> StageConfig | None: + """Look up a stage by its ID. + + Args: + stage_id: The stage ID to search for. + + Returns: + The matching StageConfig, or None if not found. + """ + for stage in self.stages: + if stage.stage_id == stage_id: + return stage + return None + + def validate_pipeline(self) -> list[str]: + """Validate pipeline topology at model integration time (not runtime). + + Checks: + - All stage IDs are unique + - All input_sources reference valid stage IDs + - At least one entry point (stage with empty input_sources) + + Returns: + List of validation error messages. Empty list if valid. 
+ """ + errors: list[str] = [] + + if not self.stages: + errors.append("Topology has no stages defined") + return errors + + # Check for unique stage IDs + stage_ids = [s.stage_id for s in self.stages] + if len(stage_ids) != len(set(stage_ids)): + errors.append("Duplicate stage IDs found") + + stage_id_set = set(stage_ids) + + # Check input_sources reference valid stages + for stage in self.stages: + for source_id in stage.input_sources: + if source_id not in stage_id_set: + errors.append(f"Stage {stage.stage_id} references non-existent input source {source_id}") + if source_id == stage.stage_id: + errors.append(f"Stage {stage.stage_id} references itself as input source") + + # Check for at least one entry point + entry_points = [s for s in self.stages if not s.input_sources] + if not entry_points: + errors.append("No entry point found (stage with empty input_sources)") + + return errors + + +class StageConfigFactory: + """Factory that loads pipeline YAML and merges CLI overrides. + + Handles both single-stage and multi-stage models. + """ + + # Mapping of model types to pipeline directories under model_pipelines/. + PIPELINE_MODELS: dict[str, str] = { + "qwen3_omni_moe": "qwen3_omni_moe", + "qwen2_5_omni": "qwen2_5_omni", + "bagel": "bagel", + "qwen3_tts": "qwen3_tts", + } + + @classmethod + def create_from_model( + cls, + model: str, + cli_overrides: dict[str, Any] | None = None, + ) -> list[StageConfig] | None: + """Load pipeline YAML, merge with CLI overrides. + + Args: + model: Model name or path. + cli_overrides: CLI overrides from VllmConfig/OmniDiffusionConfig. + + Returns: + List of StageConfig objects with CLI overrides applied, + or None if no pipeline definition was found for this model. 
+ """ + if cli_overrides is None: + cli_overrides = {} + + trust_remote_code = cli_overrides.get("trust_remote_code", True) + pipeline = cls._load_pipeline(model, trust_remote_code=trust_remote_code) + + if pipeline is None: + return None + + errors = pipeline.validate_pipeline() + if errors: + logger.warning(f"Pipeline validation warnings for {model}: {errors}") + + # Apply CLI overrides + result: list[StageConfig] = [] + for stage in pipeline.stages: + # Merge global CLI overrides + stage.runtime_overrides = cls._merge_cli_overrides(stage, cli_overrides) + result.append(stage) + + return result + + @classmethod + def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]]: + """Single-stage diffusion - no YAML needed. + + Creates a default diffusion stage configuration for single-stage + diffusion models. Returns a legacy OmegaConf-compatible dict for + backward compatibility with OmniStage. + + Args: + kwargs: Engine arguments from CLI/API. + + Returns: + List containing a single config dict for the diffusion stage. 
+ """ + # Calculate devices based on parallel config + devices = "0" + if "parallel_config" in kwargs: + num_devices = kwargs["parallel_config"].world_size + for i in range(1, num_devices): + devices += f",{i}" + + # Collect engine args – skip non-serializable objects + engine_args: dict[str, Any] = {} + for key, value in kwargs.items(): + if key in ("parallel_config",): + continue + engine_args[key] = value + + engine_args.setdefault("cache_backend", "none") + engine_args["model_stage"] = "diffusion" + + # Convert dtype to string for OmegaConf + if "dtype" in engine_args: + engine_args["dtype"] = str(engine_args["dtype"]) + + config_dict: dict[str, Any] = { + "stage_id": 0, + "stage_type": StageType.DIFFUSION.value, + "runtime": { + "process": True, + "devices": devices, + "max_batch_size": 1, + }, + "engine_args": create_config(engine_args), + "final_output": True, + "final_output_type": "image", + } + + return [config_dict] + + @classmethod + def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipeline | None: + """Load pipeline YAML for the model. + + Args: + model: Model name or path. + trust_remote_code: Whether to trust remote code for HF config loading. + + Returns: + ModelPipeline if found, None otherwise. + """ + model_type = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code) + if model_type is None: + return None + + pipeline_dir = cls.PIPELINE_MODELS.get(model_type) + if pipeline_dir is None: + logger.debug(f"No pipeline mapping for model_type: {model_type}") + return None + + pipeline_path = get_pipeline_path(pipeline_dir, "pipeline.yaml") + + if not pipeline_path.exists(): + logger.debug(f"Pipeline file not found: {pipeline_path}") + return None + + return cls._parse_pipeline_yaml(pipeline_path, model_type) + + @classmethod + def _parse_pipeline_yaml(cls, path: Path, model_type: str) -> ModelPipeline: + """Parse a pipeline YAML file. + + Args: + path: Path to the YAML file. 
+ model_type: Model type identifier. + + Returns: + ModelPipeline object. + """ + config_data = load_yaml_config(path) + + stages: list[StageConfig] = [] + for stage_data in config_data.stages: + # Use .get() for optional fields — idiomatic for OmegaConf DictConfig + stage_type_str = stage_data.get("stage_type", "llm") + stage_type = StageType(stage_type_str) if stage_type_str else StageType.LLM + + # Handle both 'input_sources' (new) and 'engine_input_source' (legacy) + input_sources = stage_data.get("input_sources", None) + if input_sources is None: + input_sources = stage_data.get("engine_input_source", []) + if input_sources is None: + input_sources = [] + input_sources = list(input_sources) + + stage = StageConfig( + stage_id=stage_data.stage_id, + model_stage=stage_data.model_stage, + stage_type=stage_type, + input_sources=input_sources, + custom_process_input_func=stage_data.get("custom_process_input_func", None), + final_output=stage_data.get("final_output", False), + final_output_type=stage_data.get("final_output_type", None), + worker_type=stage_data.get("worker_type", None), + scheduler_cls=stage_data.get("scheduler_cls", None), + hf_config_name=stage_data.get("hf_config_name", None), + is_comprehension=stage_data.get("is_comprehension", False), + ) + stages.append(stage) + + # Get optional connector config + connectors = to_dict(config_data.connectors) if hasattr(config_data, "connectors") else None + edges = to_dict(config_data.edges) if hasattr(config_data, "edges") else None + + return ModelPipeline( + model_type=getattr(config_data, "model_type", model_type), + stages=stages, + connectors=connectors, + edges=edges, + ) + + @classmethod + def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) -> str | None: + """Auto-detect model_type from model directory. + + Args: + model: Model name or path. + trust_remote_code: Whether to trust remote code for HF config loading. + + Returns: + Model type string if detected, None otherwise. 
+ """ + try: + from vllm.transformers_utils.config import get_config + + hf_config = get_config(model, trust_remote_code=trust_remote_code) + return hf_config.model_type + except Exception as e: + logger.debug(f"Failed to auto-detect model type for {model}: {e}") + return None + + # Keys that should never be forwarded as engine overrides (internal / + # orchestrator-only knobs, complex objects, etc.). + _INTERNAL_KEYS: set[str] = { + "model", + "stage_configs_path", + "stage_id", + "stage_init_timeout", + "init_timeout", + "shm_threshold_bytes", + "worker_backend", + "ray_address", + "batch_timeout", + "log_stats", + "tokenizer", + "parallel_config", + } + + @classmethod + def _merge_cli_overrides( + cls, + stage: StageConfig, + cli_overrides: dict[str, Any], + ) -> dict[str, Any]: + """Merge CLI overrides into stage runtime config. + + All CLI arguments registered by engine config classes (e.g. + EngineArgs / OmniDiffusionConfig) are accepted as overrides + unless they appear in ``_INTERNAL_KEYS``. + + Handles: + - Global overrides (apply to all stages) + - Per-stage overrides (--stage-N-* format, take precedence) + + Args: + stage: The stage to merge overrides into. + cli_overrides: CLI arguments from VllmConfig/OmniDiffusionConfig. + + Returns: + Dict of runtime overrides for this stage. + """ + result: dict[str, Any] = {} + + # Apply global overrides – any key not in the internal blocklist + # is forwarded so that engine-registered params work out of the box. 
+ for key, value in cli_overrides.items(): + if key in cls._INTERNAL_KEYS: + continue + if re.match(r"stage_\d+_", key): + # Per-stage keys handled below + continue + if value is not None: + result[key] = value + + # Apply per-stage overrides (--stage-N-* format, take precedence) + stage_prefix = f"stage_{stage.stage_id}_" + for key, value in cli_overrides.items(): + if key.startswith(stage_prefix) and value is not None: + param_name = key[len(stage_prefix) :] + if param_name in cls._INTERNAL_KEYS: + continue + result[param_name] = value + + return result diff --git a/vllm_omni/config/yaml_util.py b/vllm_omni/config/yaml_util.py new file mode 100644 index 0000000000..09b27ca022 --- /dev/null +++ b/vllm_omni/config/yaml_util.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Centralized OmegaConf wrapper for vLLM-Omni. + +All OmegaConf usage in the project MUST go through this module. +Other modules should import these wrapper functions instead of +using OmegaConf directly. +""" + +from __future__ import annotations + +from typing import Any + +from omegaconf import DictConfig, OmegaConf + + +def load_yaml_config(path: str | Any) -> DictConfig: + """Load a YAML file and return it as a DictConfig. + + Args: + path: Path to the YAML file. + + Returns: + OmegaConf DictConfig with attribute-style access. + """ + return OmegaConf.load(path) # type: ignore[return-value] + + +def create_config(data: Any) -> DictConfig: + """Wrap a dict (or list) into a DictConfig. + + Args: + data: Dict, list, or other structure to wrap. + + Returns: + OmegaConf DictConfig / ListConfig. + """ + return OmegaConf.create(data) + + +def merge_configs(*cfgs: Any) -> dict: + """Deep-merge multiple configs and return a plain dict. + + Args: + *cfgs: DictConfig or dict objects to merge (left to right). + + Returns: + Plain dict with merged, resolved values. 
+ """ + merged = OmegaConf.merge(*cfgs) + return OmegaConf.to_container(merged, resolve=True) # type: ignore[return-value] + + +def to_dict(obj: Any, *, resolve: bool = True) -> Any: + """Convert a DictConfig (or similar) to a plain dict. + + Args: + obj: OmegaConf container to convert. + resolve: Whether to resolve interpolations (default True). + + Returns: + Plain dict. + """ + return OmegaConf.to_container(obj, resolve=resolve) # type: ignore[return-value] diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 5880ec5948..1327566fc4 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -14,13 +14,14 @@ import huggingface_hub import msgspec.msgpack import zmq -from omegaconf import OmegaConf from tqdm.auto import tqdm from vllm import SamplingParams from vllm.logger import init_logger from vllm.utils.network_utils import make_zmq_socket from vllm.v1.utils import get_engine_client_zmq_addr +from vllm_omni.config.stage_config import StageConfigFactory +from vllm_omni.config.yaml_util import create_config from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, @@ -211,43 +212,29 @@ def _normalize_cache_config(self, cache_backend: str | None, cache_config: Any | cache_config = self._get_default_cache_config(cache_backend) return cache_config - def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]: - """Create default diffusion stage configuration.""" - # We temporally create a default config for diffusion stage. - # In the future, we should merge the default config with the user-provided config. - # TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error. - if "dtype" in kwargs: - kwargs["dtype"] = str(kwargs["dtype"]) + def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> list[dict[str, Any]]: + """Create default diffusion stage configuration. 
+ + Uses StageConfigFactory for typed configuration creation while + maintaining backward compatibility with the legacy format. + + Args: + kwargs: Engine arguments from CLI/API. + + Returns: + List containing a single OmegaConf config for the diffusion stage. + """ + # Normalize cache config before passing to factory cache_backend = kwargs.get("cache_backend", "none") cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) - # TODO: hack, calculate devices based on parallel config. - devices = "0" - if "parallel_config" in kwargs: - num_devices = kwargs["parallel_config"].world_size - for i in range(1, num_devices): - devices += f",{i}" - default_stage_cfg = [ - { - "stage_id": 0, - "stage_type": "diffusion", - "runtime": { - "process": True, - "devices": devices, - "max_batch_size": 1, - }, - "engine_args": OmegaConf.create( - { - **kwargs, - "cache_backend": cache_backend, - "cache_config": cache_config, - } - ), - "final_output": True, - "final_output_type": "image", - } - ] - default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" - return default_stage_cfg + + # Update kwargs with normalized values + kwargs_copy = dict(kwargs) + kwargs_copy["cache_backend"] = cache_backend + kwargs_copy["cache_config"] = cache_config + + # Use the factory to create default diffusion config + return StageConfigFactory.create_default_diffusion(kwargs_copy) def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]: """Resolve stage configs and inject defaults shared by orchestrator/headless.""" @@ -274,7 +261,7 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st if getattr(cfg, "stage_type", None) != "diffusion": continue if not hasattr(cfg, "engine_args") or cfg.engine_args is None: - cfg.engine_args = OmegaConf.create({}) + cfg.engine_args = create_config({}) if kwargs.get("lora_path") is not None: if not hasattr(cfg.engine_args, "lora_path") or 
cfg.engine_args.lora_path is None: cfg.engine_args.lora_path = kwargs["lora_path"] diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 2a151f18f1..1b0dbad0e3 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -264,7 +264,12 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): self.engine_args = stage_config.engine_args self.model_stage = stage_config.engine_args.model_stage self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) - self.engine_input_source = getattr(stage_config, "engine_input_source", []) + # Support both 'input_sources' (new format) and 'engine_input_source' (legacy) + self.engine_input_source = getattr(stage_config, "input_sources", None) + if self.engine_input_source is None: + self.engine_input_source = getattr(stage_config, "engine_input_source", []) + if self.engine_input_source is None: + self.engine_input_source = [] self.engine_output_type = getattr(stage_config.engine_args, "engine_output_type", None) self.engine_outputs = None self.is_comprehension = getattr(stage_config, "is_comprehension", False) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 5eeb1ab0fa..9dec144a8c 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -7,7 +7,7 @@ from multiprocessing import shared_memory as _shm from typing import Any -from omegaconf import OmegaConf +from vllm_omni.config.yaml_util import to_dict as _omega_to_dict logger = logging.getLogger(__name__) @@ -296,7 +296,7 @@ def _to_dict(x: Any) -> dict[str, Any]: try: if isinstance(x, dict): return dict(x) - return OmegaConf.to_container(x, resolve=True) # type: ignore[arg-type] + return _omega_to_dict(x) except Exception: try: return dict(x) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 69ce73fc47..0e31bfa7c2 100644 --- 
a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -5,11 +5,11 @@ from pathlib import Path from typing import Any, get_args, get_origin -from omegaconf import OmegaConf from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_file_to_dict from vllm.transformers_utils.repo_utils import file_or_path_exists +from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs from vllm_omni.entrypoints.stage_utils import _to_dict from vllm_omni.platforms import current_omni_platform @@ -234,6 +234,11 @@ def resolve_model_config_path(model: str) -> str: def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: """Load stage configurations from model's default config file. + .. deprecated:: + This is the legacy OmegaConf-based loading path. New code should use + ``StageConfigFactory.create_from_model()`` instead. This function will + be removed once all callers are migrated (see PR series [2/N]). + Loads stage configurations based on the model type and device type. First tries to load a device-specific YAML file from stage_configs/{device_type}/ directory. If not found, falls back to the default config file. @@ -259,6 +264,9 @@ def load_stage_configs_from_model(model: str, base_engine_args: dict | None = No def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None = None) -> list: """Load stage configurations from a YAML file. + .. deprecated:: + Legacy OmegaConf-based loader. Will be removed in PR series [2/N]. 
+ Args: config_path: Path to the YAML configuration file @@ -267,17 +275,17 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None """ if base_engine_args is None: base_engine_args = {} - config_data = OmegaConf.load(config_path) + config_data = load_yaml_config(config_path) stage_args = config_data.stage_args global_async_chunk = config_data.get("async_chunk", False) - # Convert any nested dataclass objects to dicts before creating OmegaConf + # Convert any nested dataclass objects to dicts before creating DictConfig base_engine_args = _convert_dataclasses_to_dict(base_engine_args) - base_engine_args = OmegaConf.create(base_engine_args) + base_engine_args = create_config(base_engine_args) for stage_arg in stage_args: base_engine_args_tmp = base_engine_args.copy() # Update base_engine_args with stage-specific engine_args if they exist if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None: - base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args) + base_engine_args_tmp = create_config(merge_configs(base_engine_args_tmp, stage_arg.engine_args)) stage_type = getattr(stage_arg, "stage_type", "llm") if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None and stage_type != "diffusion": runtime_cfg = stage_arg.runtime @@ -312,7 +320,7 @@ def load_and_resolve_stage_configs( if not stage_configs: if default_stage_cfg_factory is not None: default_stage_cfg = default_stage_cfg_factory() - stage_configs = OmegaConf.create(default_stage_cfg) + stage_configs = create_config(default_stage_cfg) else: stage_configs = [] else: @@ -347,6 +355,7 @@ def get_final_stage_id_for_e2e( output_modalities = default_modalities try: + final_stage_id_for_e2e = last_stage_id for _sid in range(last_stage_id, -1, -1): if ( getattr(stage_list[_sid], "final_output", False) @@ -354,8 +363,6 @@ def get_final_stage_id_for_e2e( ): final_stage_id_for_e2e = _sid break - if final_stage_id_for_e2e < 0: - 
final_stage_id_for_e2e = last_stage_id except Exception as e: logger.debug( "[Orchestrator] Failed to determine final stage for E2E; \ diff --git a/vllm_omni/model_pipelines/__init__.py b/vllm_omni/model_pipelines/__init__.py new file mode 100644 index 0000000000..9d73dd2051 --- /dev/null +++ b/vllm_omni/model_pipelines/__init__.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Model pipeline definitions for vLLM-Omni. + +Each subdirectory contains: +- pipeline.yaml: Pipeline structure (stages, types, data-flow). +- default_args.yaml: Default runtime / engine args per stage. + +Runtime parameters (GPU memory, tensor-parallel size, etc.) can be +overridden via CLI flags. +""" + +from pathlib import Path + +PIPELINES_DIR = Path(__file__).parent + + +def get_pipeline_path(model_dir: str, filename: str) -> Path: + """Return the full path to a pipeline YAML file. + + Args: + model_dir: Model subdirectory name (e.g., "qwen3_omni_moe"). + filename: Name of the YAML file (e.g., "pipeline.yaml"). + + Returns: + Absolute path to the file. + """ + return PIPELINES_DIR / model_dir / filename diff --git a/vllm_omni/model_pipelines/bagel/default_args.yaml b/vllm_omni/model_pipelines/bagel/default_args.yaml new file mode 100644 index 0000000000..bdd0c0cfd6 --- /dev/null +++ b/vllm_omni/model_pipelines/bagel/default_args.yaml @@ -0,0 +1,60 @@ +# Default runtime args for Bagel +# These can be overridden via CLI flags (e.g. --gpu-memory-utilization). 
+ +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: BagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.35 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: true + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.55 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + default_sampling_params: + seed: 52 + +# Runtime edges +runtime: + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold diff --git a/vllm_omni/model_pipelines/bagel/pipeline.yaml b/vllm_omni/model_pipelines/bagel/pipeline.yaml new file mode 100644 index 0000000000..668b1fafe1 --- /dev/null +++ b/vllm_omni/model_pipelines/bagel/pipeline.yaml @@ -0,0 +1,34 @@ +# Model Pipeline Config for Bagel +# Defines pipeline structure (stages, types, data-flow). +# Runtime params (gpu_memory_utilization, tp_size, etc.) come from CLI. 
+# +# Stage 0: Thinker (LLM - multimodal understanding + text generation) +# Stage 1: DiT (Diffusion - image generation from KV cache) + +model_type: bagel + +stages: + - stage_id: 0 + model_stage: thinker + stage_type: llm + input_sources: [] # Entry point - no upstream stages + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + final_output: true + final_output_type: text + is_comprehension: true + + - stage_id: 1 + model_stage: dit + stage_type: diffusion + input_sources: [0] # Receives from thinker + final_output: true + final_output_type: image + +connectors: + shared_memory_connector: + name: SharedMemoryConnector + +edges: + - from: 0 + to: 1 diff --git a/vllm_omni/model_pipelines/qwen2_5_omni/default_args.yaml b/vllm_omni/model_pipelines/qwen2_5_omni/default_args.yaml new file mode 100644 index 0000000000..7289c2f94f --- /dev/null +++ b/vllm_omni/model_pipelines/qwen2_5_omni/default_args.yaml @@ -0,0 +1,76 @@ +# Default runtime args for Qwen2.5-Omni +# These can be overridden via CLI flags (e.g. --gpu-memory-utilization). 
+ +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: true + repetition_penalty: 1.1 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: latent + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: true + repetition_penalty: 1.05 + stop_token_ids: [8294] + + - stage_id: 2 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: true + repetition_penalty: 1.1 diff --git a/vllm_omni/model_pipelines/qwen2_5_omni/pipeline.yaml b/vllm_omni/model_pipelines/qwen2_5_omni/pipeline.yaml new file mode 100644 index 0000000000..6c4393b60b --- /dev/null +++ b/vllm_omni/model_pipelines/qwen2_5_omni/pipeline.yaml @@ -0,0 +1,37 @@ 
+# Model Pipeline Config for Qwen2.5-Omni +# Defines pipeline structure (stages, types, data-flow). +# Runtime params (gpu_memory_utilization, tp_size, etc.) come from CLI. +# +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings -> audio codec codes) +# Stage 2: Code2Wav (codec codes -> audio waveform) + +model_type: qwen2_5_omni + +stages: + - stage_id: 0 + model_stage: thinker + stage_type: llm + input_sources: [] # Entry point - no upstream stages + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + final_output: true + final_output_type: text + is_comprehension: true + + - stage_id: 1 + model_stage: talker + stage_type: llm + input_sources: [0] # Receives from thinker + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + + - stage_id: 2 + model_stage: code2wav + stage_type: llm + input_sources: [1] # Receives from talker + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + final_output: true + final_output_type: audio diff --git a/vllm_omni/model_pipelines/qwen3_omni_moe/default_args.yaml b/vllm_omni/model_pipelines/qwen3_omni_moe/default_args.yaml new file mode 100644 index 0000000000..9ea7951e3d --- /dev/null +++ b/vllm_omni/model_pipelines/qwen3_omni_moe/default_args.yaml @@ -0,0 +1,82 @@ +# Default runtime args for Qwen3-Omni-MoE +# These can be overridden via CLI flags (e.g. --gpu-memory-utilization). 
+ +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 64 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: true + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 64 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "1" + max_batch_size: 64 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.1 diff --git 
a/vllm_omni/model_pipelines/qwen3_omni_moe/pipeline.yaml b/vllm_omni/model_pipelines/qwen3_omni_moe/pipeline.yaml new file mode 100644 index 0000000000..894e0a1ece --- /dev/null +++ b/vllm_omni/model_pipelines/qwen3_omni_moe/pipeline.yaml @@ -0,0 +1,41 @@ +# Model Pipeline Config for Qwen3-Omni-MoE +# Defines pipeline structure (stages, types, data-flow). +# Runtime params (gpu_memory_utilization, tp_size, etc.) come from CLI. +# +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform) + +model_type: qwen3_omni_moe + +stages: + - stage_id: 0 + model_stage: thinker + stage_type: llm + input_sources: [] # Entry point - no upstream stages + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + hf_config_name: thinker_config + final_output: true + final_output_type: text + is_comprehension: true + + - stage_id: 1 + model_stage: talker + stage_type: llm + input_sources: [0] # Receives from thinker + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + hf_config_name: talker_config + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + + - stage_id: 2 + model_stage: code2wav + stage_type: llm + input_sources: [1] # Receives from talker + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + hf_config_name: thinker_config + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio diff --git a/vllm_omni/model_pipelines/qwen3_tts/default_args.yaml b/vllm_omni/model_pipelines/qwen3_tts/default_args.yaml new file mode 100644 index 0000000000..52e595db12 --- /dev/null +++ b/vllm_omni/model_pipelines/qwen3_tts/default_args.yaml @@ -0,0 +1,72 @@ +# Default runtime args for 
Qwen3-TTS +# These can be overridden via CLI flags (e.g. --gpu-memory-utilization). + +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSTalkerForConditionalGeneration + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: false + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + hf_overrides: + architectures: [Qwen3TTSCode2Wav] + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.2 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 32768 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + codec_chunk_frames: 25 + codec_left_context_frames: 25 diff --git a/vllm_omni/model_pipelines/qwen3_tts/pipeline.yaml b/vllm_omni/model_pipelines/qwen3_tts/pipeline.yaml new file mode 100644 index 0000000000..e575ea5984 --- /dev/null +++ 
b/vllm_omni/model_pipelines/qwen3_tts/pipeline.yaml @@ -0,0 +1,25 @@ +# Model Pipeline Config for Qwen3-TTS +# Defines pipeline structure (stages, types, data-flow). +# Runtime params (gpu_memory_utilization, tp_size, etc.) come from CLI. +# +# Stage 0: Qwen3-TTS (text -> audio codec codes) +# Stage 1: Code2Wav (codec codes -> audio waveform) + +model_type: qwen3_tts + +stages: + - stage_id: 0 + model_stage: qwen3_tts + stage_type: llm + input_sources: [] # Entry point - no upstream stages + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + + - stage_id: 1 + model_stage: code2wav + stage_type: llm + input_sources: [0] # Receives from qwen3_tts (stage 0) + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + final_output: true + final_output_type: audio