Skip to content

Commit 36b8f80

Browse files
wuhang2014, princepride, Copilot
authored
[Feature] Support Stage Based Deployment CLI (#939)
Signed-off-by: wuhang <wuhang6@huawei.com> Signed-off-by: princepride <wangzhipeng628@gmail.com> Signed-off-by: wuhang <whlbx@hotmail.com> Co-authored-by: 汪志鹏 <wangzhipeng628@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 06d347b commit 36b8f80

File tree

11 files changed

+1272
-198
lines changed

11 files changed

+1272
-198
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#!/bin/bash
# Bagel multi-stage online serving startup script.
# Starts stage 1 (DiT) in headless mode in the background, then stage 0
# (Thinker) in the foreground as the master with the API server.
#
# Configurable via environment variables:
#   MODEL           - model name/path      (default: ByteDance-Seed/BAGEL-7B-MoT)
#   PORT            - API server port      (default: 8091)
#   MASTER_ADDRESS  - master bind address  (default: 127.0.0.1)
#   MASTER_PORT     - master RPC port      (default: 8092)

# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}"
PORT="${PORT:-8091}"
MASTER_ADDRESS="${MASTER_ADDRESS:-127.0.0.1}"
MASTER_PORT="${MASTER_PORT:-8092}"
# Stage config YAML resolved relative to this script's location.
STAGE_CONFIGS_PATH="$(dirname "$0")/../../../vllm_omni/model_executor/stage_configs/bagel.yaml"

echo "Starting Bagel multi-stage server..."
echo "Model: $MODEL"
echo "API Port: $PORT"
echo "Master Address: $MASTER_ADDRESS"
echo "Master Port: $MASTER_PORT"
echo "Stage Configs: $STAGE_CONFIGS_PATH"

# Start stage 1 (DiT) in headless mode first
echo "Starting Stage 1 (DiT) in headless mode..."
vllm serve "$MODEL" --omni \
    --stage-configs-path "$STAGE_CONFIGS_PATH" \
    --stage-id 1 \
    --headless \
    -oma "$MASTER_ADDRESS" \
    -omp "$MASTER_PORT" &
STAGE1_PID=$!

# Make sure the background stage 1 process does not outlive this script:
# kill it on normal exit, Ctrl-C, or termination. Without this, exiting
# stage 0 leaves an orphaned headless vllm process holding MASTER_PORT.
trap 'kill "$STAGE1_PID" 2>/dev/null || true' EXIT INT TERM

# Start stage 0 (Thinker) as master with API server
echo "Starting Stage 0 (Thinker) as master..."
vllm serve "$MODEL" --omni \
    --port "$PORT" \
    --stage-configs-path "$STAGE_CONFIGS_PATH" \
    --stage-id 0 \
    -oma "$MASTER_ADDRESS" \
    -omp "$MASTER_PORT"

requirements/common.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ imageio[ffmpeg]>=2.37.2
1414
sox>=1.5.0
1515
prettytable>=3.8.0
1616
aenum==3.1.16
17+
pyzmq>=25.0.0

tests/entrypoints/test_async_omni_diffusion_config.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33

44
import pytest
55

6-
from vllm_omni.entrypoints import omni as omni_module
6+
from vllm_omni.entrypoints import utils as utils_module
77
from vllm_omni.entrypoints.async_omni import AsyncOmni
88

99
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
1010

11+
MODEL = "riverclouds/qwen_image_random"
12+
1113

1214
def test_default_stage_config_includes_cache_backend(monkeypatch):
1315
"""Ensure cache_backend/cache_config are preserved in default diffusion stage."""
14-
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
15-
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
16+
monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
17+
monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None)
1618
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
1719
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
1820

1921
omni = AsyncOmni(
20-
model="dummy-model",
22+
model=MODEL,
2123
cache_backend="cache_dit",
2224
cache_config='{"Fn_compute_blocks": 2}',
2325
vae_use_slicing=True,
@@ -41,13 +43,13 @@ def test_default_stage_config_includes_cache_backend(monkeypatch):
4143

4244
def test_default_cache_config_used_when_missing(monkeypatch):
4345
"""Ensure default cache_config is applied when cache_backend is set."""
44-
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
45-
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
46+
monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
47+
monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None)
4648
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
4749
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
4850

4951
omni = AsyncOmni(
50-
model="dummy-model",
52+
model=MODEL,
5153
cache_backend="cache_dit",
5254
)
5355

@@ -59,13 +61,13 @@ def test_default_cache_config_used_when_missing(monkeypatch):
5961

6062
def test_default_stage_devices_from_sequence_parallel(monkeypatch):
6163
"""Ensure devices list reflects sequence parallel size when no parallel_config is provided."""
62-
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
63-
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
64+
monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
65+
monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None)
6466
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
6567
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
6668

6769
omni = AsyncOmni(
68-
model="dummy-model",
70+
model=MODEL,
6971
ulysses_degree=2,
7072
ring_degree=2,
7173
)

0 commit comments

Comments (0)