14 changes: 11 additions & 3 deletions tests/e2e/singlecard/ops/test_fused_moe.py
@@ -92,8 +92,15 @@ def test_fused_experts(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)

output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())

output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
@@ -148,7 +155,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)

topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@@ -169,6 +176,7 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)


@pytest.mark.parametrize("device", DEVICE)
2 changes: 1 addition & 1 deletion tests/ut/ops/test_fused_ops.py
@@ -405,7 +405,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
46 changes: 23 additions & 23 deletions tests/ut/quantization/test_w8a8.py
@@ -719,25 +719,25 @@ def setUp(self):
def test_softmax_scoring(self):
"""Test softmax scoring function"""

weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""

weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -760,13 +760,13 @@ def test_grouped_topk(self, mock_topk):
self.top_k,
dtype=torch.long))

weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)

mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
self.num_experts)

e_score_correction_bias = torch.randn(self.num_experts)
weights, ids = select_experts(
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
self.top_k,
dtype=torch.int32))

weights, ids = select_experts(
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -824,7 +824,7 @@ def test_renormalize(self, mock_topk):
self.top_k,
dtype=torch.long))

weights, _ = select_experts(
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
self.top_k,
dtype=torch.long))

weights, ids = select_experts(
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
7 changes: 7 additions & 0 deletions tests/ut/quantization/test_w8a8_dynamic.py
@@ -55,6 +55,12 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
row_idx_len = self.num_tokens * 8
row_idx = (torch.arange(
0,
row_idx_len,
dtype=torch.int32,
).view(8, -1).permute(1, 0).contiguous())

result = fused_experts_with_all2all(
hidden_states=self.placeholder,
@@ -64,6 +70,7 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
row_idx=row_idx,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
2 changes: 1 addition & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -68,7 +68,7 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
31 changes: 7 additions & 24 deletions vllm_ascend/ops/fused_moe.py
@@ -391,6 +391,7 @@ def fused_experts_with_all2all(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
@@ -401,17 +402,10 @@

num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device

if expert_map is not None:
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -445,12 +439,6 @@

hidden_states = hidden_states[sorted_idx]
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -524,6 +512,7 @@ def fused_experts_with_all2all_buffer(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
@@ -535,14 +524,10 @@
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

num_tokens, _ = hidden_states.shape
device = hidden_states.device

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
@@ -755,6 +740,7 @@ def fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
@@ -846,12 +832,6 @@
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
else:
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
@@ -975,7 +955,7 @@ def apply(
**kwargs,
) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -1019,6 +999,7 @@ def apply(
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map)
elif MOE_ALL2ALL_BUFFER:
Expand All @@ -1028,6 +1009,7 @@ def apply(
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
max_model_len=self.max_model_len,
global_batch_size=self.global_batch_size,
@@ -1049,6 +1031,7 @@ def apply(
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map,
ep_group=get_ep_group())
27 changes: 21 additions & 6 deletions vllm_ascend/ops/layers/experts_selector.py
@@ -20,6 +20,17 @@
import torch_npu


def return_row_idx(hidden_states, top_k):
num_tokens = hidden_states.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
return row_idx
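
# Editor's note, not part of the diff: a minimal standalone sketch of the layout
# the new return_row_idx helper produces, assuming plain CPU torch (the device
# argument is dropped here) and illustrative sizes. Row indices are laid out
# column-major over the (num_tokens, top_k) grid, so entry [i, j] is
# j * num_tokens + i.
```python
import torch

num_tokens, top_k = 4, 2
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32)
           .view(top_k, -1)
           .permute(1, 0)
           .contiguous())
print(row_idx)
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)
```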


def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""

topk_weights, topk_ids = _select_experts_with_fusion_ops(
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
@@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
return topk_weights, topk_ids
if row_idx is None:
row_idx = return_row_idx(hidden_states, top_k)
return topk_weights, topk_ids, row_idx
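
# Editor's note, not part of the diff: a hedged usage sketch of the updated
# three-value return. Shapes are illustrative, the keyword arguments mirror the
# test calls above, and an NPU environment with torch_npu available is assumed.
```python
import torch
from vllm_ascend.ops.layers.experts_selector import select_experts

num_tokens, num_experts, top_k = 4, 16, 2
hidden_states = torch.randn(num_tokens, 128)           # (num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)   # (num_tokens, num_experts)

# select_experts now also returns row_idx, which callers pass through to
# fused_experts / fused_experts_with_all2all instead of rebuilding it there.
topk_weights, topk_ids, row_idx = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
    use_grouped_topk=False,
    renormalize=False,
    scoring_func="softmax",
)
assert row_idx.shape == (num_tokens, top_k)  # same shape as topk_ids
```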


def _native_grouped_topk(
Expand Down Expand Up @@ -156,6 +170,7 @@ def _select_expert_use_group_topk(


def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
Expand All @@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops(
global_num_experts: int = -1,
is_unquantized: bool = False):

topk_weights, topk_ids = None, None
topk_weights, topk_ids, row_idx = None, None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
Expand All @@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops(
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))

row_idx = return_row_idx(hidden_states, top_k)
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

return topk_weights, topk_ids
return topk_weights, topk_ids, row_idx


def _native_select_experts(
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/quantization/w4a8_dynamic.py
@@ -268,7 +268,7 @@ def apply(
1] == global_num_experts, "Number of global experts mismatch"

# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -334,6 +334,7 @@ def apply(
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,