Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def set_ascend_forward_context(

from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

max_num_tokens = int(num_tokens_across_dp.max().item()) if num_tokens_across_dp is not None else num_tokens
moe_comm_type = select_moe_comm_method(max_num_tokens, vllm_config, is_draft_model)

moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change removes the logic for determining the maximum number of tokens across data parallel (DP) ranks. It now relies on the num_tokens argument being consistent across all DP ranks.

However, num_tokens may not be consistent. Specifically, in NPUModelRunner._sync_batch_across_dp, the all_reduce operation is skipped if _skip_all_reduce_across_dp_group() returns True (e.g., for non-MoE models or certain MoE configurations). In that case, the num_tokens passed to this function is each rank's local token count, which can differ across ranks.

This will cause select_moe_comm_method to be called with different num_tokens values on different DP ranks, potentially leading to desynchronization and a hang if they choose different communication methods. This is a critical issue.

While the previous logic was also affected by the issue in _skip_all_reduce_across_dp_group, it correctly showed the intent of using a synchronized maximum token count. This change removes that safeguard.

Suggested change
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
max_num_tokens = int(num_tokens_across_dp.max().item()) if num_tokens_across_dp is not None else num_tokens
moe_comm_type = select_moe_comm_method(max_num_tokens, vllm_config, is_draft_model)

forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

Expand Down
18 changes: 5 additions & 13 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,6 @@ def execute_model(
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=max_num_scheduled_tokens,
use_cascade_attn=cascade_attn_prefix_lens is not None,
force_eager=self.model_config.enforce_eager,
num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
)

Expand Down Expand Up @@ -1854,7 +1853,6 @@ def _sync_batch_across_dp(
self,
num_tokens_padded: int | None = None,
cudagraph_mode: int = 0,
allow_dp_padding: bool = False,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
Expand Down Expand Up @@ -1898,16 +1896,11 @@ def _sync_batch_across_dp(

num_tokens_across_dp = tensor[0, :]
max_num_tokens = int(num_tokens_across_dp.max().item())

if allow_dp_padding:
num_tokens_after_padding = torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp),
device="cpu",
dtype=torch.int32,
)
else:
num_tokens_after_padding = num_tokens_across_dp.cpu()

num_tokens_after_padding = torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp),
device="cpu",
dtype=torch.int32,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return False, num_tokens_after_padding, synced_cudagraph_mode
Expand Down Expand Up @@ -1976,7 +1969,6 @@ def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None):
_, num_tokens_across_dp, synced_cudagraph_mode = self._sync_batch_across_dp(
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
allow_dp_padding=cudagraph_mode != CUDAGraphMode.NONE,
)

# Extract DP padding if there is any
Expand Down
Loading