tests/e2e/singlecard/ops/test_fused_moe.py (3 additions, 1 deletion)

@@ -148,7 +148,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)

-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, row_idx = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
@@ -169,6 +169,8 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
     assert topk_ids.dtype == torch.int32
+    assert row_idx.shape == (m, topk)
+    assert row_idx.dtype == torch.int32


 @pytest.mark.parametrize("device", DEVICE)
tests/ut/ops/test_fused_ops.py (2 additions, 1 deletion)

@@ -400,7 +400,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(

    [Review comment, Collaborator] Could add a test case to cover the
    fused_experts function with the row_idx parameter passed in. (A sketch of
    such a test follows this file's diff.)

             hidden_states=x,
             router_logits=router_logits,
             top_k=2,
@@ -415,3 +415,4 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

         assert topk_weights.shape == (8, 2)
         assert topk_ids.shape == (8, 2)
+        assert row_idx.shape == (8, 2)
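A minimal sketch of the test the reviewer asks for, modeled on
test_select_experts above. It is a sketch only: it assumes it sits in the
same test class with the same mock_dist_env/mock_moe_env fixtures, and the
w1/w2 expert-weight shapes and keyword names are illustrative guesses at the
fused_experts signature, not code from this PR.

    import torch

    from vllm_ascend.ops.fused_moe import fused_experts
    from vllm_ascend.ops.layers.experts_selector import select_experts

    def test_fused_experts_with_row_idx(self, mock_dist_env, mock_moe_env):
        x = torch.randn(8, 2)
        router_logits = torch.randn(8, 2)
        # Illustrative weights: 2 experts, hidden size 2, intermediate size 4.
        w1 = torch.randn(2, 4, 2)
        w2 = torch.randn(2, 2, 4)

        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=False)

        # Thread the routing metadata from select_experts straight through,
        # exercising the new row_idx parameter.
        output = fused_experts(hidden_states=x,
                               w1=w1,
                               w2=w2,
                               topk_weights=topk_weights,
                               topk_ids=topk_ids,
                               top_k=2,
                               row_idx=row_idx)

        assert output.shape == x.shape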
tests/ut/quantization/test_w8a8.py (23 additions, 23 deletions)

@@ -719,25 +719,25 @@ def setUp(self):
     def test_softmax_scoring(self):
         """Test softmax scoring function"""

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="softmax")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="softmax")

         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="sigmoid")

         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -760,13 +760,13 @@ def test_grouped_topk(self, mock_topk):
                                             self.top_k,
                                             dtype=torch.long))

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=True,
+                                         renormalize=False,
+                                         topk_group=4,
+                                         num_expert_group=2)

         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                                 self.num_experts)

         e_score_correction_bias = torch.randn(self.num_experts)
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
                                              self.top_k,
                                              dtype=torch.int32))

-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -824,7 +824,7 @@ def test_renormalize(self, mock_topk):
                                             self.top_k,
                                             dtype=torch.long))

-        weights, _ = select_experts(
+        weights, _, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
                                             self.top_k,
                                             dtype=torch.long))

-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
vllm_ascend/ops/common_fused_moe.py (1 addition, 1 deletion)

@@ -68,7 +68,7 @@ def forward_oot(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
vllm_ascend/ops/fused_moe.py (11 additions, 8 deletions)

@@ -759,6 +759,7 @@ def fused_experts(
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,
     max_num_tokens: Optional[int] = None,
+    row_idx: torch.Tensor = None,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -846,12 +847,13 @@
         # Rearrange hidden_states
         sorted_hidden_states = hidden_states[sorted_token_indices]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
+        if row_idx is None:
+            row_idx_len = num_tokens * top_k
+            row_idx = (torch.arange(0,
+                                    row_idx_len,
+                                    dtype=torch.int32,
+                                    device=device).view(top_k, -1).permute(
+                                        1, 0).contiguous())

    [Review comment, Collaborator] Seems all layers use the same row_idx, can
    we just construct it once? (A sketch follows this hunk.)

         active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
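A sketch of what the reviewer suggests: since every layer uses the same
row_idx for a given (num_tokens, top_k, device), it could be built once and
reused. The cache helper below is hypothetical, not part of this PR; the
trailing comment shows the token-major layout the arange/view/permute
construction produces.

    import torch

    _ROW_IDX_CACHE: dict = {}

    def get_row_idx(num_tokens: int, top_k: int, device) -> torch.Tensor:
        """Hypothetical helper: build row_idx once per shape/device, reuse it."""
        key = (num_tokens, top_k, str(device))
        if key not in _ROW_IDX_CACHE:
            row_idx_len = num_tokens * top_k
            # Same construction as fused_experts above, so
            # row_idx[t, k] == k * num_tokens + t.
            _ROW_IDX_CACHE[key] = (torch.arange(
                0, row_idx_len, dtype=torch.int32,
                device=device).view(top_k, -1).permute(1, 0).contiguous())
        return _ROW_IDX_CACHE[key]

    # For num_tokens=3, top_k=2 this yields:
    # tensor([[0, 3],
    #         [1, 4],
    #         [2, 5]], dtype=torch.int32)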
@@ -975,7 +977,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -1020,7 +1022,8 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,
-                expert_map=expert_map)
+                expert_map=expert_map,
+                row_idx=row_idx)
         elif MOE_ALL2ALL_BUFFER:
             return fused_experts_with_all2all_buffer(
                 hidden_states=x,
vllm_ascend/ops/layers/experts_selector.py (6 additions, 7 deletions)

@@ -56,7 +56,7 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """

-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
         router_logits=router_logits,
         top_k=top_k,
         use_grouped_topk=use_grouped_topk,
@@ -83,7 +83,7 @@ def select_experts(hidden_states: torch.Tensor,
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts,
    )
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx


 def _native_grouped_topk(
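Read together with the fused_experts change above, the new contract of
select_experts is: the third return value, row_idx, is populated only when
the npu_moe_gating_top_k_softmax fusion path runs (see the next hunk) and is
None otherwise, in which case fused_experts rebuilds it on demand. A minimal
usage sketch, with illustrative shapes:

    import torch
    from vllm_ascend.ops.layers.experts_selector import select_experts

    hidden_states = torch.randn(8, 16)   # illustrative shapes
    router_logits = torch.randn(8, 4)

    topk_weights, topk_ids, row_idx = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=2,
        use_grouped_topk=False,
        renormalize=False)

    # row_idx is only produced by the fused softmax kernel; on other routing
    # paths it stays None and fused_experts constructs it itself.
    if row_idx is not None:
        assert row_idx.shape == topk_ids.shape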
@@ -168,7 +168,7 @@ def _select_experts_with_fusion_ops(
         global_num_experts: int = -1,
         is_unquantized: bool = False):

-    topk_weights, topk_ids = None, None
+    topk_weights, topk_ids, row_idx = None, None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
@@ -186,14 +186,13 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+    elif not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
+        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(

    [Review comment, Contributor, severity: high] This if block can be
    executed even if the preceding `if is_deepseek_v3_r1:` block was already
    executed. This would cause the results from the specialized
    npu_moe_gating_top_k path to be overwritten by this more general
    npu_moe_gating_top_k_softmax path, which is likely not the intended
    behavior. To ensure that only one of these specialized fusion paths is
    taken, consider changing the if on line 190 to an elif.

    [Author reply, @loukong33, Aug 18, 2025] fixed

             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx


 def _native_select_experts(
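A condensed skeleton of the corrected dispatch, illustrating the reviewer's
point: with two independent if blocks the general softmax kernel could
overwrite the specialized DeepSeek-V3/R1 result, whereas elif makes the
fusion paths mutually exclusive. The torch_npu calls are elided; this is a
sketch of the control flow only:

    def _select_experts_with_fusion_ops_sketch(global_num_experts: int,
                                               use_grouped_topk: bool,
                                               custom_routing_function,
                                               scoring_func: str,
                                               is_unquantized: bool):
        topk_weights, topk_ids, row_idx = None, None, None
        is_deepseek_v3_r1 = global_num_experts == 256
        if is_deepseek_v3_r1:
            pass  # npu_moe_gating_top_k path (group_count=256 pattern)
        elif (not use_grouped_topk and custom_routing_function is None
              and scoring_func == "softmax" and is_unquantized):
            pass  # npu_moe_gating_top_k_softmax path, which also yields row_idx
        return topk_weights, topk_ids, row_idx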
vllm_ascend/quantization/w4a8_dynamic.py (1 addition, 1 deletion)

@@ -245,7 +245,7 @@ def apply(
             1] == global_num_experts, "Number of global experts mismatch"

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
vllm_ascend/quantization/w8a8.py (1 addition, 1 deletion)

@@ -241,7 +241,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
vllm_ascend/quantization/w8a8_dynamic.py (1 addition, 1 deletion)

@@ -903,7 +903,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,