35 changes: 35 additions & 0 deletions tests/e2e/multicard/test_qwen3_moe.py
@@ -21,6 +21,8 @@
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
"""

import os

from modelscope import snapshot_download # type: ignore

from tests.e2e.conftest import VllmRunner
@@ -72,3 +74,36 @@ def test_models_distributed_Qwen3_MOE_W8A8():
            enforce_eager=False,
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV():
    # Run the ACL-graph path with device-side (AIV) expansion of HCCL
    # communication operators.
    os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
    example_prompts = [
        "Hello, my name is",
    ]
    dtype = "auto"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=2,
            enforce_eager=False,
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
    # Run the ACL-graph path with HCCL's default expansion mode; clear the
    # variable in case a previous test set it.
    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
        del os.environ['HCCL_OP_EXPANSION_MODE']
    example_prompts = [
        "Hello, my name is",
    ]
    dtype = "auto"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=2,
            enforce_eager=False,
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -55,6 +55,9 @@
    # Please make sure that the version is correct.
    "SOC_VERSION":
    lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
    # Controls where HCCL expands (orchestrates) its communication algorithms,
    # e.g. 'AIV' for device-side expansion.
Collaborator review comment: this is an env from HCCL; we should not add it in vllm-ascend. We can set it in the Dockerfile and mention it in the docs.
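A hypothetical sketch of the alternative the reviewer describes, assuming the variable is set at deployment time (for example via `ENV HCCL_OP_EXPANSION_MODE=AIV` in a Dockerfile) rather than registered in envs.py; the launch-script snippet below is illustrative only and not part of this PR:

import os

# Set HCCL's expansion mode in the process environment before the engine
# starts, instead of reading it through vllm-ascend's env registry.
os.environ.setdefault("HCCL_OP_EXPANSION_MODE", "AIV")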

"HCCL_OP_EXPANSION_MODE":
lambda: os.environ.get("HCCL_OP_EXPANSION_MODE", None),
# If set, vllm-ascend will print verbose logs during compilation
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
47 changes: 38 additions & 9 deletions vllm_ascend/utils.py
@@ -325,17 +325,46 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    num_hidden_layers = get_max_hidden_layers(hf_config)
    parallel_config = vllm_config.parallel_config

    # TODO: Find out whether we need to take into account the pp_size
    parallel_factor = 1 + sum(size > 1 for size in [
        parallel_config.data_parallel_size_local,
    num_comm_groups = sum(size > 1 for size in [
        parallel_config.data_parallel_size,
        parallel_config.tensor_parallel_size,
    ])

    # Calculate maximum supported batch sizes considering model architecture
    max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
                                     (num_hidden_layers + 1) / parallel_factor)
    logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
                max_num_batch_sizes)
    if envs_ascend.HCCL_OP_EXPANSION_MODE == 'AIV':
        # TODO: Find out whether we need to take into account the pp_size
        parallel_factor = 1 + num_comm_groups + int(
            parallel_config.enable_expert_parallel)
        # Calculate the maximum number of supported batch sizes, taking the
        # model architecture into account (A2 hardware). Example:
        # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size
        # is 1, tensor_parallel_size is 4, expert parallelism disabled, so
        # max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
        max_num_batch_sizes = math.floor(
            MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)
        logger.info(
            "Calculated maximum supported batch sizes for ACL graph: %s",
            max_num_batch_sizes)
    else:
        # The formula above is empirical and applies to A2 hardware, where HCCL
        # uses the FFTS+ method for execution unfolding: it adds only one
        # concurrent stream and does not consume extra collective-communication
        # unfolding streams. On A3 hardware, HCCL defaults to the AICPU method,
        # which in the worst case may additionally allocate up to
        # rank_size (max 16) - 1 streams per collective-communication domain on
        # the device. Using that default unfolding method on A3 significantly
        # reduces the maximum supported sizes, so the formula is adjusted as
        # follows. Example:
        # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size
        # is 1, tensor_parallel_size is 4, so max_num_batch_sizes =
        # math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
        max_num_batch_sizes = math.floor(
            (MAX_CAPTURE_SIZE - num_comm_groups * 40) /
            (num_hidden_layers + 1) / (1 + num_comm_groups * 2))
        logger.info(
            "Calculated maximum supported batch sizes for ACL graph: %s",
            max_num_batch_sizes)
        logger.warning(
            "Currently, communication is performed using the FFTS+ method, "
            "which reduces the number of available streams and, as a result, "
            "limits the range of runtime shapes that can be handled. To both "
            "improve communication performance and increase the number of "
            "supported shapes, set HCCL_OP_EXPANSION_MODE=AIV.")

    # If original sizes exceed maximum, sample a representative subset
    if max_num_batch_sizes < len(original_sizes):
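For reference, a minimal standalone sketch that reproduces the two sizing formulas above with the example values from the comments (MAX_CAPTURE_SIZE = 1920, 48 hidden layers, tensor_parallel_size = 4, data_parallel_size = 1, expert parallelism off); the variable names are local to this sketch, not part of the diff:

import math

MAX_CAPTURE_SIZE = 1920
num_hidden_layers = 48
num_comm_groups = 1          # only the TP group (size 4) is larger than 1
enable_expert_parallel = False

# HCCL_OP_EXPANSION_MODE == 'AIV': divide the capture budget by the
# per-layer cost and the parallel factor.
parallel_factor = 1 + num_comm_groups + int(enable_expert_parallel)
aiv_sizes = math.floor(MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)

# Default expansion mode: subtract 40 capture slots per communication group
# and weight each group with two extra streams in the divisor.
default_sizes = math.floor((MAX_CAPTURE_SIZE - num_comm_groups * 40) /
                           (num_hidden_layers + 1) / (1 + num_comm_groups * 2))

print(aiv_sizes, default_sizes)  # -> 19 12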