```diff
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -920,143 +921,6 @@ def fused_experts(
     return final_hidden_states
 
 
-def native_grouped_topk(
-    topk_weights: torch.Tensor,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
-):
-    topk_group = 0 if topk_group is None else topk_group
-    num_expert_group = 0 if num_expert_group is None else num_expert_group
-
-    num_token = topk_weights.shape[0]
-    grouped_weights = topk_weights.view(num_token, num_expert_group,
-                                        -1).max(dim=-1).values
-    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
-                                    k=topk_group,
-                                    dim=-1,
-                                    sorted=False)[1]
-    topk_group_mask = torch.zeros_like(grouped_weights)
-    topk_group_mask.scatter_(1, topk_group_indices, 1)
-    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
-    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
-
-    return topk_weights
-
-
-def select_experts(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[torch.Tensor] = None
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        use_grouped_topk: Whether to group experts before selecting top-k.
-        renormalize: Whether to renormalize the routing weights.
-        topk_group: Number of expert groups to select from.
-        num_expert_group: Number of experts in each group.
-        custom_routing_function: Custom routing function.
-        scoring_func: Scoring function to use.
-        e_score_correction_bias: Correction bias to apply to expert scores.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-
-    def _renormalize_topk_weights(
-        topk_weights: torch.Tensor,
-        renormalize: bool,
-    ):
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-        return topk_weights
-
-    if scoring_func == "softmax":
-        # NOTE: vLLM use dtype=torch.float here
-        if not use_grouped_topk and custom_routing_function is None:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-                x=router_logits, finished=None, k=top_k)
-            topk_ids = topk_ids.to(torch.int32)
-            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-            return topk_weights, topk_ids
-
-        topk_weights = router_logits.softmax(dim=-1)
-    elif scoring_func == "sigmoid":
-        topk_weights = router_logits.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use biased
-            # scores for expert selection but original scores for routing weights
-            original_weights = topk_weights
-            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
-
-        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
-        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
-                                           topk_group)
-        # TODO bfloat16 is not supported in torch.topk with ge graph.
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                  k=top_k,
-                                  dim=-1,
-                                  sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_weights.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                                k=top_k,
-                                                dim=-1,
-                                                sorted=False)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-        return topk_weights, topk_ids
-
-    if custom_routing_function is not None:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            global_num_experts=global_num_experts)
-        # Required by npu_moe_init_routing
-        topk_ids = topk_ids.to(torch.int32)
-        return topk_weights, topk_ids
-
-    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-    topk_weights = topk_weights.to(hidden_states.dtype)
-
-    # Required by npu_moe_init_routing
-    topk_ids = topk_ids.to(torch.int32)
-    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-
-    return topk_weights, topk_ids
-
-
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):
```
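The removed `native_grouped_topk` implemented group-limited routing as used by DeepSeek-style MoE models: each token first selects the strongest expert *groups*, then selects experts only within those groups. For intuition, here is a minimal, self-contained CPU sketch of that behavior; it mirrors the mask-and-topk structure of the deleted helper, but the function name and the small shapes are illustrative only, not the relocated vllm-ascend code.

```python
import torch


def grouped_topk_demo(scores: torch.Tensor, num_expert_group: int,
                      topk_group: int, top_k: int):
    """Reference-only sketch of group-limited top-k routing (CPU, fp32)."""
    num_token = scores.shape[0]
    # Score each group by its single best expert, as the deleted helper did.
    grouped = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Keep only the `topk_group` strongest groups per token.
    group_idx = torch.topk(grouped, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(grouped).scatter_(1, group_idx, 1)
    # Broadcast the group mask back to per-expert granularity.
    expert_mask = (group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1))
    masked = scores.masked_fill(~expert_mask.bool(), 0.0)
    # Final per-token expert selection within the surviving groups.
    return torch.topk(masked, k=top_k, dim=-1, sorted=False)


# 2 tokens, 16 experts in 4 groups: each token routes to 2 experts
# drawn from its 2 strongest groups.
scores = torch.rand(2, 16).softmax(dim=-1)
weights, ids = grouped_topk_demo(scores, num_expert_group=4, topk_group=2,
                                 top_k=2)
```

Masking the losing groups to `0.0` rather than `-inf` is safe here because the scores are softmax or sigmoid outputs and therefore non-negative, so masked experts cannot win the final top-k.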
```diff
@@ -1111,36 +975,19 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        is_deepseek_v3_r1 = global_num_experts == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=
-                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False,  # todo new api; should the third output be output
-                # y2_flag=False,  # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts,
+            is_unquantized=True)
 
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
```