Commit cfe77e8

[Bugfix] Support Qwen3-MOE on aclgraph mode in sizes capture and add new UT (#2511)
What this PR does / why we need it?
This PR fixes the sizes-capture and stream errors that occur when running the Qwen3-30B MoE model with ACL graph enabled, and adds new unit tests.

Does this PR introduce any user-facing change?
No.

How was this patch tested?
Unit tests.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@6fad29b

Signed-off-by: lilinsiman <[email protected]>
1 parent b3fdd78 commit cfe77e8
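For context (a sketch, not part of this commit): the capture-size budget computed by `update_aclgraph_sizes` depends on how HCCL unfolds collective operations, which is selected through the `HCCL_OP_EXPANSION_MODE` environment variable. The variable has to be set before the engine is constructed, e.g.:

import os

# Hypothetical launch snippet: opt in to AIV execution unfolding so that
# update_aclgraph_sizes() applies the larger AIV capture-size budget.
os.environ["HCCL_OP_EXPANSION_MODE"] = "AIV"
# ... construct the vLLM engine / VllmRunner afterwards ...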

File tree

3 files changed: +80 −7 lines changed

tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 35 additions & 0 deletions
@@ -21,6 +21,8 @@
 Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
 """
 
+import os
+
 from modelscope import snapshot_download  # type: ignore
 
 from tests.e2e.conftest import VllmRunner
@@ -72,3 +74,36 @@ def test_models_distributed_Qwen3_MOE_W8A8():
             enforce_eager=False,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV():
+    os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "auto"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "auto"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
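The new cases can be run via the file-level command from the docstring, or selected individually with pytest's `-k` filter (a usage sketch; the exact selection pattern is assumed):

pytest tests/e2e/multicard/test_qwen3_moe.py -k "ACLGRAPH"

Note that the AIV test mutates `os.environ` for the whole process, which is why the non-AIV test defensively deletes `HCCL_OP_EXPANSION_MODE` before running.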

tests/ut/test_utils.py

Lines changed: 6 additions & 0 deletions
@@ -255,6 +255,9 @@ def test_update_aclgraph_sizes(self):
             parallel_config=test_parallel_config,
         )
         utils.update_aclgraph_sizes(test_vllm_config)
+        os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
+        utils.update_aclgraph_sizes(test_vllm_config)
+        del os.environ['HCCL_OP_EXPANSION_MODE']
         self.assertEqual(
             147,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -267,6 +270,9 @@ def test_update_aclgraph_sizes(self):
             parallel_config=test_parallel_config,
         )
         utils.update_aclgraph_sizes(test_vllm_config)
+        os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
+        utils.update_aclgraph_sizes(test_vllm_config)
+        del os.environ['HCCL_OP_EXPANSION_MODE']
         self.assertEqual(
             3,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
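As a side note (a sketch, not part of this commit), `unittest.mock.patch.dict` achieves the same set-then-restore effect as the manual set/del pair above, and restores `os.environ` even if an assertion in the body raises:

import os
from unittest import mock

# Sketch only: patch.dict snapshots os.environ on entry and restores it on
# exit, so no manual `del os.environ[...]` cleanup is needed.
with mock.patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"}):
    assert os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV"
    # utils.update_aclgraph_sizes(test_vllm_config) would run here.
# Outside the block the variable is back to its previous state.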

vllm_ascend/utils.py

Lines changed: 39 additions & 7 deletions
@@ -20,6 +20,7 @@
 import atexit
 import functools
 import math
+import os
 from contextlib import contextmanager
 from enum import Enum
 from threading import Lock
@@ -303,16 +304,47 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     parallel_config = vllm_config.parallel_config
 
     # TODO: Find out whether we need to take into account the pp_size
-    parallel_factor = 1 + sum(size > 1 for size in [
-        parallel_config.data_parallel_size_local,
+    num_comm_groups = sum(size > 1 for size in [
+        parallel_config.data_parallel_size,
         parallel_config.tensor_parallel_size,
     ])
 
-    # Calculate maximum supported batch sizes considering model architecture
-    max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
-                                     (num_hidden_layers + 1) / parallel_factor)
-    logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
-                max_num_batch_sizes)
+    if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
+        # TODO: Find out whether we need to take into account the pp_size
+        parallel_factor = 1 + num_comm_groups + int(
+            parallel_config.enable_expert_parallel)
+        # Calculate maximum supported batch sizes considering model architecture on the A2 hardware device.
+        # Assume the following case:
+        # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4.
+        # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
+        max_num_batch_sizes = math.floor(
+            MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)
+        logger.info(
+            "Calculated maximum supported batch sizes for ACL graph: %s",
+            max_num_batch_sizes)
+    else:
+        # The above is an empirical formula applicable to A2 hardware.
+        # Under that configuration, HCCL employs the FFTS+ method for execution unfolding,
+        # which adds only 1 concurrent stream without consuming collective-communication unfolding streams.
+        # On A3 hardware, HCCL defaults to the AICPU method.
+        # In the worst case, that approach may additionally allocate up to rank_size (max 16) - 1 streams
+        # per collective-communication domain on the device.
+        # Using the default unfolding method on A3 therefore significantly reduces the maximum supported sizes,
+        # so the calculation formula has been modified as follows:
+        # Assume the following case:
+        # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4.
+        # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
+        max_num_batch_sizes = math.floor(
+            (MAX_CAPTURE_SIZE - num_comm_groups * 40) /
+            (num_hidden_layers + 1) / (1 + num_comm_groups * 2))
+        logger.info(
+            "Calculated maximum supported batch sizes for ACL graph: %s",
+            max_num_batch_sizes)
+        logger.warning(
+            "Currently, communication is performed using FFTS+ method, which reduces "
+            "the number of available streams and, as a result, limits the range of runtime "
+            "shapes that can be handled. To both improve communication performance and "
+            "increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV."
+        )
 
     # If original sizes exceed maximum, sample a representative subset
     if max_num_batch_sizes < len(original_sizes):
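To make the two branches concrete, here is a small standalone check (a sketch reproducing the worked case from the comments above; MAX_CAPTURE_SIZE = 1920, 48 hidden layers, dp = 1, tp = 4, expert parallelism off are taken from those comments):

import math

MAX_CAPTURE_SIZE = 1920
num_hidden_layers = 48
dp, tp = 1, 4
num_comm_groups = sum(size > 1 for size in [dp, tp])  # -> 1 (only TP > 1)

# AIV unfolding: one extra factor per active comm group (plus EP if enabled).
parallel_factor = 1 + num_comm_groups + int(False)  # expert parallel off
aiv = math.floor(MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)

# Default unfolding: reserve 40 capture slots per comm group and assume up
# to 2 extra streams per group.
default = math.floor((MAX_CAPTURE_SIZE - num_comm_groups * 40) /
                     (num_hidden_layers + 1) / (1 + num_comm_groups * 2))

print(aiv, default)  # 19 12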
