Skip to content

Commit 875a86c

Browse files
authored
ut: add example and e2e test for sleep mode in external_launcher (#2152)
### What this PR does / why we need it? This PR adds an e2e test case to make sure sleep mode in external_launcher works correctly. ### Does this PR introduce _any_ user-facing change? Not involved. ### How was this patch tested? Not involved. - vLLM version: v0.10.0 - vLLM main: vllm-project/vllm@74333ae Signed-off-by: huangxialu <[email protected]>
1 parent 8a59367 commit 875a86c

File tree

2 files changed

+132
-8
lines changed

2 files changed

+132
-8
lines changed

examples/offline_external_launcher.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@
2828
--proc-per-node=2
2929
MOE models:
3030
python examples/offline_external_launcher.py \
31-
--model="Qwen/Qwen3-0.6B" \
31+
--model="Qwen/Qwen3-30B-A3B" \
3232
--tp-size=2 \
3333
--proc-per-node=2 \
3434
--enable-expert-parallel
3535
3636
Multi-node:
3737
Node 0 (assume the node has ip of 10.99.48.128):
3838
python examples/offline_external_launcher.py \
39-
--model="Qwen/Qwen3-0.6B" \
39+
--model="Qwen/Qwen3-30B-A3B" \
4040
--tp-size=2 \
4141
--node-size=2 \
4242
--node-rank=0 \
@@ -46,7 +46,7 @@
4646
--master-port=13345
4747
Node 1:
4848
python examples/offline_external_launcher.py \
49-
--model="Qwen/Qwen3-0.6B" \
49+
--model="Qwen/Qwen3-30B-A3B" \
5050
--tp-size=2 \
5151
--node-size=2 \
5252
--node-rank=1 \
@@ -66,7 +66,7 @@
6666
from vllm import LLM, SamplingParams
6767
from vllm.distributed.parallel_state import ( # noqa E402
6868
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
69-
from vllm.utils import get_open_port
69+
from vllm.utils import get_open_port, GiB_bytes
7070

7171
os.environ["VLLM_USE_MODELSCOPE"] = "True"
7272
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -114,20 +114,44 @@ def parse_args():
114114
parser.add_argument("--enable-expert-parallel",
115115
action="store_true",
116116
help="Enable expert parallel, used in MOE models.")
117-
return parser.parse_args()
117+
parser.add_argument("--enable-sleep-mode",
118+
action="store_true",
119+
help="Enable sleep mode for the engine.")
120+
parser.add_argument("--temperature",
121+
type=float,
122+
default=0.8,
123+
help="Float that controls the randomness of the sampling.")
124+
parser.add_argument("--model-weight-gib",
125+
type=float,
126+
default=None,
127+
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
128+
129+
args = parser.parse_args()
130+
if args.enable_sleep_mode:
131+
if args.model_weight_gib is None or args.temperature != 0:
132+
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
133+
if args.model_weight_gib <= 0:
134+
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
135+
if args.model == parser.get_default("model") and args.model_weight_gib is None:
136+
parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.")
137+
138+
return args
118139

119140

120141
def main(
121142
local_rank: int,
122143
rank: int,
123144
master_addr: str,
124145
master_port: int,
146+
model_weight_gib: float,
125147
model: str = "Qwen/Qwen3-0.6B",
126148
world_size: int = 4,
127149
tensor_parallel_size: int = 2,
128150
enable_expert_parallel: bool = False,
129151
enforce_eager: bool = False,
130152
trust_remote_code: bool = True,
153+
enable_sleep_mode: bool = False,
154+
temperature: float = 0.8,
131155
):
132156
os.environ["MASTER_ADDR"] = master_addr
133157
os.environ["MASTER_PORT"] = str(master_port)
@@ -147,7 +171,7 @@ def main(
147171
"The future of AI is",
148172
] * 10
149173
sampling_params = SamplingParams(
150-
temperature=0.8,
174+
temperature=temperature,
151175
top_p=0.95,
152176
max_tokens=10,
153177
)
@@ -159,10 +183,31 @@ def main(
159183
trust_remote_code=trust_remote_code,
160184
distributed_executor_backend="external_launcher",
161185
seed=0,
186+
enable_sleep_mode=enable_sleep_mode,
162187
)
163188
tp_ranks = get_tp_group().ranks
164189
print(f'TP RANKS: {tp_ranks}')
190+
165191
outputs = llm.generate(prompts, sampling_params)
192+
193+
if enable_sleep_mode:
194+
if rank == 0:
195+
free_bytes_before_sleep, total = torch.npu.mem_get_info()
196+
llm.sleep(level=1)
197+
if rank == 0:
198+
free_bytes_after_sleep, total = torch.npu.mem_get_info()
199+
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
200+
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
201+
# now the freed memory should be larger than the model weights
202+
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
203+
204+
llm.wake_up()
205+
outputs_after_wakeup = llm.generate(prompts, sampling_params)
206+
if rank == 0:
207+
# cmp output
208+
assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text
209+
print("Sleep and wake up successfully!!")
210+
166211
for i, output in enumerate(outputs):
167212
if i >= 5:
168213
# print only 5 outputs
@@ -214,12 +259,15 @@ def cleanup_env_and_memory():
214259
rank,
215260
master_addr,
216261
master_port,
262+
args.model_weight_gib,
217263
args.model,
218264
world_size,
219265
tp_size,
220266
args.enable_expert_parallel,
221267
args.enforce_eager,
222268
args.trust_remote_code,
269+
args.enable_sleep_mode,
270+
args.temperature,
223271
))
224272

225273
proc.start()

tests/e2e/multicard/test_external_launcher.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,14 @@
2424
import subprocess
2525
import sys
2626
from pathlib import Path
27-
from unittest.mock import patch
2827

2928
import pytest
3029

3130
MODELS = ["Qwen/Qwen3-0.6B"]
31+
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
3232

3333

3434
@pytest.mark.parametrize("model", MODELS)
35-
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
3635
def test_external_launcher(model):
3736
script = Path(
3837
__file__
@@ -71,3 +70,80 @@ def test_external_launcher(model):
7170
assert "TP RANKS: [1]" in output
7271
assert "Generated text:" in output
7372
assert proc.returncode == 0
73+
74+
75+
@pytest.mark.parametrize("model", MOE_MODELS)
76+
def test_moe_external_launcher(model):
77+
script = Path(
78+
__file__
79+
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
80+
env = os.environ.copy()
81+
# TODO: Change to 2 when ci machine has 4 cards
82+
cmd = [
83+
sys.executable,
84+
str(script), "--model", model, "--tp-size", "2", "--node-size", "1",
85+
"--node-rank", "0", "--proc-per-node", "2", "--trust-remote-code",
86+
"--enable-expert-parallel"
87+
]
88+
89+
print(f"Running subprocess: {' '.join(cmd)}")
90+
proc = subprocess.run(
91+
cmd,
92+
env=env,
93+
stdout=subprocess.PIPE,
94+
stderr=subprocess.STDOUT,
95+
timeout=600,
96+
)
97+
output = proc.stdout.decode()
98+
99+
print(output)
100+
101+
assert "TP RANKS: [0, 1]" in output
102+
assert "Generated text:" in output
103+
assert proc.returncode == 0
104+
105+
106+
def test_external_launcher_and_sleepmode():
107+
script = Path(
108+
__file__
109+
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
110+
env = os.environ.copy()
111+
# TODO: Change to 2 when ci machine has 4 cards
112+
cmd = [
113+
sys.executable,
114+
str(script),
115+
"--model",
116+
"Qwen/Qwen3-8B",
117+
"--tp-size",
118+
"1",
119+
"--node-size",
120+
"1",
121+
"--node-rank",
122+
"0",
123+
"--proc-per-node",
124+
"2",
125+
"--trust-remote-code",
126+
"--enable-sleep-mode",
127+
"--temperature",
128+
"0",
129+
"--model-weight-gib",
130+
"16",
131+
]
132+
133+
print(f"Running subprocess: {' '.join(cmd)}")
134+
proc = subprocess.run(
135+
cmd,
136+
env=env,
137+
stdout=subprocess.PIPE,
138+
stderr=subprocess.STDOUT,
139+
timeout=300,
140+
)
141+
output = proc.stdout.decode()
142+
143+
print(output)
144+
145+
assert "TP RANKS: [0]" in output
146+
assert "TP RANKS: [1]" in output
147+
assert "Generated text:" in output
148+
assert "Sleep and wake up successfully!!" in output
149+
assert proc.returncode == 0

0 commit comments

Comments
 (0)