Skip to content

Commit 7176385

Browse files
lishunyang12 and claude committed
Address review comments: remove dead code and stage-id launch
- Remove ARCH_MAPPING, _auto_detect_model_arch, RUNTIME_PARAMS (dead code)
- Remove --stage-id / stage_id_filter (conflicts with #939)
- Remove duplicate test, fix stale comments
- Add global params clarifying comment in end2end.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1aa4640 commit 7176385

File tree

6 files changed

+8
-114
lines changed

6 files changed

+8
-114
lines changed

examples/offline_inference/qwen3_omni/end2end.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,9 @@ def main(args):
295295
else:
296296
query_result = query_func()
297297

298-
# Build kwargs with Tier-2 CLI overrides
298+
# Build kwargs with Tier-2 CLI overrides.
299+
# Global params (e.g. --gpu-memory-utilization) apply to all stages;
300+
# per-stage overrides (--stage-N-*) take precedence when specified.
299301
omni_kwargs = {
300302
"stage_configs_path": args.stage_configs_path,
301303
"log_stats": args.log_stats,
@@ -313,8 +315,6 @@ def main(args):
313315
omni_kwargs["enforce_eager"] = args.enforce_eager
314316
if args.trust_remote_code:
315317
omni_kwargs["trust_remote_code"] = args.trust_remote_code
316-
if args.stage_id is not None:
317-
omni_kwargs["stage_id"] = args.stage_id
318318

319319
omni_llm = Omni(
320320
model=model_name,
@@ -533,12 +533,6 @@ def parse_args():
533533
default=False,
534534
help="Trust remote code for model loading (Tier-2 override).",
535535
)
536-
parser.add_argument(
537-
"--stage-id",
538-
type=int,
539-
default=None,
540-
help="Launch only the specified stage ID for independent stage testing.",
541-
)
542536
parser.add_argument(
543537
"--video-path",
544538
"-v",

tests/test_config_factory.py

Lines changed: 1 addition & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -158,16 +158,6 @@ def test_duplicate_stage_ids(self):
158158
errors = topology.validate_topology()
159159
assert any("duplicate" in e.lower() for e in errors)
160160

161-
def test_mutual_dependency_detected_as_missing_entry(self):
162-
"""Test that mutual dependencies are caught (no entry point)."""
163-
stages = [
164-
StageConfig(stage_id=0, model_stage="stage_a", input_sources=[1]),
165-
StageConfig(stage_id=1, model_stage="stage_b", input_sources=[0]),
166-
]
167-
topology = StageTopology(model_type="test", stages=stages)
168-
errors = topology.validate_topology()
169-
assert any("entry point" in e.lower() for e in errors)
170-
171161
def test_self_reference(self):
172162
"""Test that self-references are detected."""
173163
stages = [
@@ -315,7 +305,7 @@ def test_cli_override_forwards_engine_registered_args(self):
315305
stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[])
316306
cli_overrides = {
317307
"gpu_memory_utilization": 0.9, # Well-known param
318-
"custom_engine_flag": True, # Engine-registered but not in RUNTIME_PARAMS
308+
"custom_engine_flag": True, # Not in _INTERNAL_KEYS, so forwarded
319309
}
320310

321311
overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides)
@@ -340,13 +330,6 @@ def test_cli_override_excludes_internal_keys(self):
340330
assert "stage_configs_path" not in overrides
341331
assert "batch_timeout" not in overrides
342332

343-
def test_arch_mapping(self):
344-
"""Test that model architecture mapping is correct."""
345-
assert StageConfigFactory.ARCH_MAPPING["qwen3_omni_moe"] == "Qwen3OmniMoeForConditionalGeneration"
346-
assert StageConfigFactory.ARCH_MAPPING["qwen2_5_omni"] == "Qwen2_5OmniForConditionalGeneration"
347-
assert StageConfigFactory.ARCH_MAPPING["bagel"] == "BagelForConditionalGeneration"
348-
assert StageConfigFactory.ARCH_MAPPING["qwen3_tts"] == "Qwen3TTSTalkerForConditionalGeneration"
349-
350333
def test_all_topology_files_exist(self):
351334
"""Test that every entry in TOPOLOGY_FILES has an actual YAML file."""
352335
from vllm_omni.model_executor.stage_topologies import get_topology_path

vllm_omni/config/stage_config.py

Lines changed: 3 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,6 @@ class StageConfigFactory:
223223
This factory is the main entry point for creating stage configurations.
224224
It handles:
225225
- Loading internal Tier-1 pipeline topology files
226-
- Auto-detecting model architecture
227226
- Merging CLI overrides (Tier-2) into stage configs
228227
- Supporting both single-stage and multi-stage models
229228
"""
@@ -238,28 +237,17 @@ class StageConfigFactory:
238237
"qwen3_tts": "qwen3_tts.yaml",
239238
}
240239

241-
# Mapping of model types to architecture classes
242-
ARCH_MAPPING: dict[str, str] = {
243-
"qwen3_omni_moe": "Qwen3OmniMoeForConditionalGeneration",
244-
"qwen2_5_omni": "Qwen2_5OmniForConditionalGeneration",
245-
"bagel": "BagelForConditionalGeneration",
246-
"qwen3_tts": "Qwen3TTSTalkerForConditionalGeneration",
247-
}
248-
249240
@classmethod
250241
def create_from_model(
251242
cls,
252243
model: str,
253244
cli_overrides: dict[str, Any] | None = None,
254-
stage_id_filter: int | None = None,
255245
) -> list[StageConfig]:
256246
"""Load internal topology, merge with CLI overrides.
257247
258248
Args:
259249
model: Model name or path.
260250
cli_overrides: Tier-2 CLI overrides from VllmConfig/OmniDiffusionConfig.
261-
stage_id_filter: If specified, only return the stage with this ID
262-
(for independent stage launch).
263251
264252
Returns:
265253
List of StageConfig objects with CLI overrides applied.
@@ -279,12 +267,9 @@ def create_from_model(
279267
if errors:
280268
logger.warning(f"Topology validation warnings for {model}: {errors}")
281269

282-
# Apply CLI overrides and filter stages
270+
# Apply CLI overrides
283271
result: list[StageConfig] = []
284272
for stage in topology.stages:
285-
if stage_id_filter is not None and stage.stage_id != stage_id_filter:
286-
continue
287-
288273
# Merge global CLI overrides
289274
stage.runtime_overrides = cls._merge_cli_overrides(stage, cli_overrides)
290275
result.append(stage)
@@ -440,46 +425,6 @@ def _auto_detect_model_type(cls, model: str) -> str | None:
440425
logger.debug(f"Failed to auto-detect model type for {model}: {e}")
441426
return None
442427

443-
@classmethod
444-
def _auto_detect_model_arch(cls, model: str) -> str | None:
445-
"""Auto-detect model_arch from model directory.
446-
447-
Args:
448-
model: Model name or path.
449-
450-
Returns:
451-
Model architecture class name if detected, None otherwise.
452-
"""
453-
model_type = cls._auto_detect_model_type(model)
454-
if model_type is None:
455-
return None
456-
457-
# Check mapping first
458-
if model_type in cls.ARCH_MAPPING:
459-
return cls.ARCH_MAPPING[model_type]
460-
461-
# Fallback: generate from model_type
462-
# Convert snake_case to PascalCase and add suffix
463-
parts = model_type.split("_")
464-
pascal_case = "".join(part.capitalize() for part in parts)
465-
return f"{pascal_case}ForConditionalGeneration"
466-
467-
# Well-known Tier-2 runtime parameters. Any CLI arg whose name
468-
# matches one of these keys is forwarded to every stage by default.
469-
# Additional engine-registered args are also accepted (see
470-
# _merge_cli_overrides), so this set does NOT need to be exhaustive.
471-
RUNTIME_PARAMS: set[str] = {
472-
"gpu_memory_utilization",
473-
"tensor_parallel_size",
474-
"devices",
475-
"enforce_eager",
476-
"max_num_batched_tokens",
477-
"trust_remote_code",
478-
"max_batch_size",
479-
"distributed_executor_backend",
480-
"enable_prefix_caching",
481-
}
482-
483428
# Keys that should never be forwarded as engine overrides (internal /
484429
# orchestrator-only knobs, complex objects, etc.).
485430
_INTERNAL_KEYS: set[str] = {
@@ -506,8 +451,8 @@ def _merge_cli_overrides(
506451
"""Merge CLI overrides into stage runtime config.
507452
508453
All CLI arguments registered by engine config classes (e.g.
509-
EngineArgs / OmniDiffusionConfig) are accepted as overrides,
510-
not just the well-known ``RUNTIME_PARAMS`` set.
454+
EngineArgs / OmniDiffusionConfig) are accepted as overrides
455+
unless they appear in ``_INTERNAL_KEYS``.
511456
512457
Handles:
513458
- Global overrides (apply to all stages)

vllm_omni/entrypoints/cli/serve.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -137,16 +137,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
137137
help="The address of the Ray cluster to connect to.",
138138
)
139139

140-
# Independent stage launch support
141-
omni_config_group.add_argument(
142-
"--stage-id",
143-
type=int,
144-
default=None,
145-
help="Launch only the specified stage ID for distributed deployments. "
146-
"Use this when deploying stages independently across nodes. "
147-
"Example: --stage-id 0 launches only the first stage.",
148-
)
149-
150140
# Diffusion model specific arguments
151141
omni_config_group.add_argument(
152142
"--num-gpus",

vllm_omni/entrypoints/omni.py

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,6 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
203203
batch_timeout = kwargs.get("batch_timeout", 10)
204204
stage_configs_path = kwargs.get("stage_configs_path", None)
205205
log_stats = kwargs.get("log_stats", False)
206-
stage_id = kwargs.get("stage_id", None) # For independent stage launch
207206

208207
### base engine args
209208
tokenizer = kwargs.get("tokenizer", None)
@@ -221,23 +220,6 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
221220
self.config_path = stage_configs_path
222221
self.stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args)
223222

224-
# Filter stages if --stage-id is specified (for independent launch).
225-
# NOTE: In independent launch mode the filtered stage occupies list
226-
# index 0 regardless of its original stage_id. This is intentional
227-
# because the stage runs in isolation without cross-stage connectors.
228-
if stage_id is not None:
229-
filtered_configs = [cfg for cfg in self.stage_configs if getattr(cfg, "stage_id", None) == stage_id]
230-
if not filtered_configs:
231-
logger.warning(
232-
f"Stage ID {stage_id} not found in configs. Available IDs: "
233-
f"{[getattr(cfg, 'stage_id', None) for cfg in self.stage_configs]}"
234-
)
235-
else:
236-
logger.info(f"Independent launch mode: loading only stage {stage_id}")
237-
self.stage_configs = (
238-
create_config(filtered_configs) if isinstance(filtered_configs[0], dict) else filtered_configs
239-
)
240-
241223
# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
242224
for cfg in self.stage_configs:
243225
try:

vllm_omni/entrypoints/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ def extract_runtime_overrides(kwargs: dict[str, Any]) -> dict[str, Any]:
248248
"""Extract Tier-2 runtime parameters from kwargs.
249249
250250
All CLI arguments registered by engine config classes are accepted,
251-
not just the well-known set in ``StageConfigFactory.RUNTIME_PARAMS``.
251+
unless they appear in ``StageConfigFactory._INTERNAL_KEYS``.
252252
Internal / orchestrator-only keys are excluded automatically.
253253
254254
Args:

0 commit comments

Comments
 (0)