Skip to content

Commit d41a26d

Browse files
committed
fix_A3_ACLgraph_sizes_capture_bug_and_add_new_ut
Signed-off-by: lilinsiman <[email protected]>
1 parent 992271b commit d41a26d

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
from modelscope import snapshot_download # type: ignore
25-
25+
import os
2626
from tests.e2e.conftest import VllmRunner
2727

2828

@@ -72,3 +72,36 @@ def test_models_distributed_Qwen3_MOE_W8A8():
7272
enforce_eager=False,
7373
) as vllm_model:
7474
vllm_model.generate_greedy(example_prompts, max_tokens)
75+
76+
77+
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV():
    """Run Qwen3-30B-A3B with TP=2 and ACL graph enabled (enforce_eager=False)
    while HCCL_OP_EXPANSION_MODE=AIV is set, exercising the AIV branch of the
    ACL graph capture-size calculation.
    """
    # Save the previous value so we can restore it afterwards: mutating
    # os.environ without cleanup leaks state into every later test in the
    # session (the companion non-AIV test depends on this var being unset).
    prev_mode = os.environ.get('HCCL_OP_EXPANSION_MODE')
    os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
    try:
        example_prompts = [
            "Hello, my name is",
        ]
        dtype = "auto"
        max_tokens = 5
        with VllmRunner(
                "Qwen/Qwen3-30B-A3B",
                dtype=dtype,
                tensor_parallel_size=2,
                enforce_eager=False,
        ) as vllm_model:
            vllm_model.generate_greedy(example_prompts, max_tokens)
    finally:
        # Restore the environment exactly as it was before the test.
        if prev_mode is None:
            os.environ.pop('HCCL_OP_EXPANSION_MODE', None)
        else:
            os.environ['HCCL_OP_EXPANSION_MODE'] = prev_mode
91+
92+
93+
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
    """Run Qwen3-30B-A3B with TP=2 and ACL graph enabled (enforce_eager=False)
    with HCCL_OP_EXPANSION_MODE unset, exercising the default (non-AIV) branch
    of the ACL graph capture-size calculation.
    """
    # Remember any pre-existing value so we can restore it afterwards:
    # unconditionally deleting the var would leak state into later tests
    # and clobber a value set by the user's environment.
    prev_mode = os.environ.pop('HCCL_OP_EXPANSION_MODE', None)
    try:
        example_prompts = [
            "Hello, my name is",
        ]
        dtype = "auto"
        max_tokens = 5
        with VllmRunner(
                "Qwen/Qwen3-30B-A3B",
                dtype=dtype,
                tensor_parallel_size=2,
                enforce_eager=False,
        ) as vllm_model:
            vllm_model.generate_greedy(example_prompts, max_tokens)
    finally:
        # Restore the environment exactly as it was before the test.
        if prev_mode is not None:
            os.environ['HCCL_OP_EXPANSION_MODE'] = prev_mode

vllm_ascend/utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -325,17 +325,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
325325
num_hidden_layers = get_max_hidden_layers(hf_config)
326326
parallel_config = vllm_config.parallel_config
327327

328+
if os.getenv("HCCL_OP_EXPANSION_MODE")=='AIV':
328329
# TODO: Find out whether we need to take into account the pp_size
329-
parallel_factor = 1 + sum(size > 1 for size in [
330-
parallel_config.data_parallel_size_local,
331-
parallel_config.tensor_parallel_size,
332-
])
333-
334-
# Calculate maximum supported batch sizes considering model architecture
335-
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
336-
(num_hidden_layers + 1) / parallel_factor)
337-
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
338-
max_num_batch_sizes)
330+
parallel_factor = 1 + sum(size > 1 for size in [
331+
parallel_config.data_parallel_size,
332+
parallel_config.tensor_parallel_size,
333+
])
334+
335+
# Calculate maximum supported batch sizes considering model architecture
336+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
337+
(num_hidden_layers + 1) / parallel_factor)
338+
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
339+
max_num_batch_sizes)
340+
else:
341+
num_comm_groups = sum(size > 1 for size in [
342+
parallel_config.data_parallel_size,
343+
parallel_config.tensor_parallel_size,
344+
])
345+
346+
max_num_batch_sizes = math.floor((MAX_CAPTURE_SIZE - num_comm_groups * 40) / (num_hidden_layers + 1) / (1 + num_comm_groups * 2))
347+
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
348+
max_num_batch_sizes)
349+
logger.warning("Unset HCCL_OP_EXPANSION_MODE prevents max size capture. Setting HCCL_OP_EXPANSION_MODE=AIV captures max sizes and boosts ACL graph performance.")
339350

340351
# If original sizes exceed maximum, sample a representative subset
341352
if max_num_batch_sizes < len(original_sizes):

0 commit comments

Comments
 (0)