
Commit e14f2ef

Authored by shiyuan680 and yangcheng (AJ)
refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
Refactors `select_experts` in the MoE module: the quantized and non-quantized implementations are merged into a single new class, used vLLM-style as `ExpertsSelector.select_experts`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested with Qwen3-MoE and all unit tests.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@e188592

Signed-off-by: yangcheng <[email protected]>
Co-authored-by: yangcheng (AJ) <[email protected]>
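After this refactor, callers import the selector from `vllm_ascend.ops.layers.experts_selector` and route through a single call. A minimal usage sketch, mirroring the unit test added in this commit (toy tensor sizes; the softmax fast path assumes an Ascend NPU runtime with `torch_npu` available):

```python
# Minimal sketch of the relocated selector API, based on the new unit test in
# tests/ut/ops/test_fused_ops.py; sizes are toy values, not a real config.
import torch

from vllm_ascend.ops.layers.experts_selector import select_experts

hidden_states = torch.randn(8, 2)   # (num_tokens, hidden_size)
router_logits = torch.randn(8, 2)   # (num_tokens, num_experts)

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    custom_routing_function=None,
    scoring_func="softmax",
    e_score_correction_bias=None,
    global_num_experts=2)

assert topk_weights.shape == (8, 2)  # per-token routing weights
assert topk_ids.shape == (8, 2)      # per-token selected expert ids
```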
1 parent 103654c commit e14f2ef

File tree

10 files changed, +359 -370 lines changed

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,8 @@
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.fused_moe import fused_experts, select_experts
+from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -142,7 +143,7 @@ def test_select_experts(
                               dtype=torch.int32)
     custom_routing_function.return_value = (mock_weights, mock_ids)
 
-    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
+    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk:
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)

tests/ut/ops/test_fused_ops.py

Lines changed: 26 additions & 0 deletions
@@ -25,6 +25,7 @@
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
 
 adapt_patch(True)
@@ -389,3 +390,28 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
             assert result.shape == (16, 2)
         else:
             assert result.shape == x.shape
+
+
+class TestExpertsSelector:
+
+    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    def test_select_experts(self, mock_dist_env, mock_moe_env,
+                            global_num_experts):
+
+        x = torch.randn(8, 2)
+        router_logits = torch.randn(8, 2)
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=2,
+            use_grouped_topk=False,
+            renormalize=True,
+            topk_group=None,
+            num_expert_group=None,
+            custom_routing_function=None,
+            scoring_func="softmax",
+            e_score_correction_bias=None,
+            global_num_experts=global_num_experts)
+
+        assert topk_weights.shape == (8, 2)
+        assert topk_ids.shape == (8, 2)

tests/ut/quantization/test_w8a8.py

Lines changed: 13 additions & 12 deletions
@@ -5,12 +5,13 @@
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
+                                                     select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
                                            fused_experts, fused_experts_310p,
-                                           native_grouped_topk,
-                                           quant_per_tensor, select_experts)
+                                           quant_per_tensor)
 
 
 class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ def test_grouped_topk(self, mock_topk):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
+    @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ def test_basic_group_selection(self):
 
         with patch('torch.topk',
                    return_value=(None, expected_topk_indices)) as mock_topk:
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=2)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=2)
 
         mock_topk.assert_called_once()
 
@@ -885,9 +886,9 @@ def test_partial_group_selection(self):
         expected_topk_indices = torch.tensor([[0], [1]])
 
         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=1)
 
         expected_result = torch.tensor(
             [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ def test_single_group(self):
         expected_topk_indices = torch.tensor([[0], [0]])
 
         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=1,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=1,
+                                          topk_group=1)
         self.assertTrue(result.numel() > 0)
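The `_native_grouped_topk` tests above mock `torch.topk`, so for readers following the rename it may help to see the grouped masking run end to end. A self-contained sketch of the same logic, taken from the `native_grouped_topk` implementation removed from `fused_moe.py` further below; the input values here are made up for illustration:

```python
# Pure-PyTorch sketch of the grouped top-k masking exercised by the tests
# above; mirrors the native_grouped_topk body removed from fused_moe.py.
import torch


def grouped_topk_mask(topk_weights: torch.Tensor, num_expert_group: int,
                      topk_group: int) -> torch.Tensor:
    num_token = topk_weights.shape[0]
    # Score each expert group by its strongest expert.
    grouped = topk_weights.view(num_token, num_expert_group,
                                -1).max(dim=-1).values
    # Keep the topk_group best groups per token.
    keep = torch.topk(grouped.float(), k=topk_group, dim=-1, sorted=False)[1]
    mask = torch.zeros_like(grouped).scatter_(1, keep, 1)
    mask = mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)
    # Zero the experts that live in the discarded groups.
    return topk_weights.masked_fill(~mask.bool(), 0.0)


weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.4, 0.3, 0.2, 0.1]])
# With 2 groups of 4 experts and topk_group=1, only the first group survives:
# tensor([[0.1000, 0.9000, 0.2000, 0.8000, 0.0000, 0.0000, 0.0000, 0.0000]])
print(grouped_topk_mask(weights, num_expert_group=2, topk_group=1))
```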

vllm_ascend/ops/common_fused_moe.py

Lines changed: 4 additions & 5 deletions
@@ -24,8 +24,8 @@
                                 UnquantizedFusedMoEMethod
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
-                                       unified_fused_experts)
+from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import is_310p
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -59,7 +59,7 @@ def forward_oot(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
-        global_num_experts: Optional[int] = None,
+        global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -69,7 +69,6 @@ def forward_oot(
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
         hidden_states=x,
         router_logits=router_logits,
         top_k=top_k,
@@ -80,7 +79,7 @@ def forward_oot(
         custom_routing_function=custom_routing_function,
         scoring_func=scoring_func,
         e_score_correction_bias=e_score_correction_bias,
-    )
+        global_num_experts=global_num_experts)
 
     if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None

vllm_ascend/ops/fused_moe.py

Lines changed: 14 additions & 167 deletions
@@ -46,6 +46,7 @@
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -920,143 +921,6 @@ def fused_experts(
     return final_hidden_states
 
 
-def native_grouped_topk(
-    topk_weights: torch.Tensor,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
-):
-    topk_group = 0 if topk_group is None else topk_group
-    num_expert_group = 0 if num_expert_group is None else num_expert_group
-
-    num_token = topk_weights.shape[0]
-    grouped_weights = topk_weights.view(num_token, num_expert_group,
-                                        -1).max(dim=-1).values
-    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
-                                    k=topk_group,
-                                    dim=-1,
-                                    sorted=False)[1]
-    topk_group_mask = torch.zeros_like(grouped_weights)
-    topk_group_mask.scatter_(1, topk_group_indices, 1)
-    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
-    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
-
-    return topk_weights
-
-
-def select_experts(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[torch.Tensor] = None
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        use_grouped_topk: Whether to group experts before selecting top-k.
-        renormalize: Whether to renormalize the routing weights.
-        topk_group: Number of expert groups to select from.
-        num_expert_group: Number of experts in each group.
-        custom_routing_function: Custom routing function.
-        scoring_func: Scoring function to use.
-        e_score_correction_bias: Correction bias to apply to expert scores.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-
-    def _renormalize_topk_weights(
-        topk_weights: torch.Tensor,
-        renormalize: bool,
-    ):
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-        return topk_weights
-
-    if scoring_func == "softmax":
-        # NOTE: vLLM use dtype=torch.float here
-        if not use_grouped_topk and custom_routing_function is None:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-                x=router_logits, finished=None, k=top_k)
-            topk_ids = topk_ids.to(torch.int32)
-            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-            return topk_weights, topk_ids
-
-        topk_weights = router_logits.softmax(dim=-1)
-    elif scoring_func == "sigmoid":
-        topk_weights = router_logits.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use biased
-            # scores for expert selection but original scores for routing weights
-            original_weights = topk_weights
-            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
-
-        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
-        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
-                                           topk_group)
-        # TODO bfloat16 is not supported in torch.topk with ge graph.
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                  k=top_k,
-                                  dim=-1,
-                                  sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_weights.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                                k=top_k,
-                                                dim=-1,
-                                                sorted=False)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-        return topk_weights, topk_ids
-
-    if custom_routing_function is not None:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            global_num_experts=global_num_experts)
-        # Required by npu_moe_init_routing
-        topk_ids = topk_ids.to(torch.int32)
-        return topk_weights, topk_ids
-
-    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-    topk_weights = topk_weights.to(hidden_states.dtype)
-
-    # Required by npu_moe_init_routing
-    topk_ids = topk_ids.to(torch.int32)
-    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-
-    return topk_weights, topk_ids
-
-
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):
@@ -1111,36 +975,19 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        is_deepseek_v3_r1 = global_num_experts == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts,
+            is_unquantized=True)
 
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
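Note that the DeepSeek-V3/R1 branch removed above (the direct `torch_npu.npu_moe_gating_top_k` call gated on `global_num_experts == 256`) no longer lives in `AscendUnquantizedFusedMoEMethod.apply`; the unquantized path now signals itself to the selector via `is_unquantized=True`. A hedged sketch of how such a dispatch could look inside the new module, with the kernel arguments copied from the removed branch; the helper name and the simplified fallback are illustrative assumptions, not the actual contents of `experts_selector.py`:

```python
# Illustrative sketch only: how the removed 256-expert fast path might be
# folded behind the new is_unquantized/global_num_experts arguments.
from typing import Optional

import torch
import torch_npu  # Ascend-only extension; this sketch assumes an NPU runtime


def _route_tokens(router_logits: torch.Tensor,
                  top_k: int,
                  topk_group: Optional[int],
                  num_expert_group: Optional[int],
                  e_score_correction_bias: Optional[torch.Tensor],
                  global_num_experts: int = -1,
                  is_unquantized: bool = False):
    # npu_moe_gating_top_k currently supports only the 256-expert
    # (DeepSeek-V3/R1) pattern, so the fused kernel is gated on that.
    if is_unquantized and global_num_experts == 256:
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,
            bias=e_score_correction_bias,
            k_group=topk_group,
            group_count=num_expert_group,
            group_select_mode=1,  # 1: sum of top-2 within each group
            renorm=0,             # 0: softmax -> topk
            norm_type=1,          # 1: sigmoid scoring
            routed_scaling_factor=1,
            eps=float(1e-20))
        return topk_weights, topk_ids

    # Simplified generic fallback (softmax scoring, top-k, renormalize),
    # condensed from the select_experts body removed above.
    scores = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(router_logits.dtype), topk_ids.to(torch.int32)
```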

vllm_ascend/ops/layers/__init__.py

Whitespace-only changes.
