@@ -66,21 +66,21 @@ def select_experts(hidden_states: torch.Tensor,
66
66
topk_ids: selected expert IDs of shape (num_tokens, top_k).
67
67
"""
68
68
69
- topk_weights , topk_ids , row_idx = _select_experts_with_fusion_ops (
70
- hidden_states = hidden_states ,
71
- router_logits = router_logits ,
72
- top_k = top_k ,
73
- use_grouped_topk = use_grouped_topk ,
74
- topk_group = topk_group ,
75
- renormalize = renormalize ,
76
- e_score_correction_bias = e_score_correction_bias ,
77
- num_expert_group = num_expert_group ,
78
- custom_routing_function = custom_routing_function ,
79
- scoring_func = scoring_func ,
80
- routed_scaling_factor = routed_scaling_factor ,
81
- global_num_experts = global_num_experts )
82
-
83
- if topk_weights is None :
69
+ topk_weights , topk_ids , row_idx = None , None , None
70
+ if custom_routing_function is None :
71
+ topk_weights , topk_ids = _select_experts_with_fusion_ops (
72
+ hidden_states = hidden_states ,
73
+ router_logits = router_logits ,
74
+ top_k = top_k ,
75
+ use_grouped_topk = use_grouped_topk ,
76
+ topk_group = topk_group ,
77
+ renormalize = renormalize ,
78
+ e_score_correction_bias = e_score_correction_bias ,
79
+ num_expert_group = num_expert_group ,
80
+ scoring_func = scoring_func ,
81
+ routed_scaling_factor = routed_scaling_factor ,
82
+ global_num_experts = global_num_experts )
83
+ else :
84
84
topk_weights , topk_ids = _native_select_experts (
85
85
hidden_states = hidden_states ,
86
86
router_logits = router_logits ,
@@ -94,8 +94,8 @@ def select_experts(hidden_states: torch.Tensor,
94
94
e_score_correction_bias = e_score_correction_bias ,
95
95
global_num_experts = global_num_experts ,
96
96
)
97
- if row_idx is None :
98
- row_idx = return_row_idx (hidden_states , top_k )
97
+ if row_idx is None :
98
+ row_idx = return_row_idx (hidden_states , top_k )
99
99
return topk_weights , topk_ids , row_idx
100
100
101
101
@@ -177,37 +177,36 @@ def _select_experts_with_fusion_ops(
177
177
e_score_correction_bias : Optional [torch .Tensor ],
178
178
topk_group : Optional [int ],
179
179
num_expert_group : Optional [int ],
180
- custom_routing_function : Optional [Callable ] = None ,
181
180
scoring_func : str = "softmax" ,
182
181
routed_scaling_factor = 1.0 ,
183
182
global_num_experts : int = - 1 ):
184
183
185
- topk_weights , topk_ids , row_idx = None , None , None
186
- # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
187
- is_deepseek_v3_r1 = global_num_experts == 256
188
- if is_deepseek_v3_r1 :
189
- topk_weights , topk_ids , _ = torch_npu . npu_moe_gating_top_k (
190
- router_logits ,
191
- k = top_k , # topk currently 8
192
- bias = e_score_correction_bias ,
193
- k_group = topk_group , # fix: 4
194
- group_count = num_expert_group , # fix 8
195
- group_select_mode =
196
- 1 , # 0: the maximum in the group; 1: topk2.sum(fix)
197
- renorm = 0 , # 0: softmax->topk(fix); 1: topk->softmax
198
- norm_type = 1 , # 0: softmax; 1: sigmoid(fix)
199
- # out_flag=False, # todo new api; should the third output be output
200
- # y2_flag=False, # old api; should the third output be output
201
- routed_scaling_factor = 1 ,
202
- eps = float ( 1e-20 ))
203
- row_idx = return_row_idx ( hidden_states , top_k )
204
- if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" :
205
- topk_weights , topk_ids , row_idx = torch_npu . npu_moe_gating_top_k_softmax (
206
- x = router_logits , finished = None , k = top_k )
207
- topk_ids = topk_ids . to ( torch . int32 )
184
+ if scoring_func == "softmax" :
185
+ norm_type = 0
186
+ topk_group = 1
187
+ num_expert_group = 1
188
+ else :
189
+ norm_type = 1
190
+ if e_score_correction_bias is not None and \
191
+ e_score_correction_bias . dtype != router_logits . dtype :
192
+ e_score_correction_bias = e_score_correction_bias . to ( router_logits . dtype )
193
+ topk_weights , topk_ids , _ = torch_npu . npu_moe_gating_top_k (
194
+ router_logits ,
195
+ k = top_k ,
196
+ bias = e_score_correction_bias ,
197
+ k_group = topk_group ,
198
+ group_count = num_expert_group ,
199
+ group_select_mode = 1 , # 0: the maximum in the group; 1: topk2.sum(fix)
200
+ renorm = 0 , # 0: softmax->topk(fix); 1: topk->softmax
201
+ norm_type = norm_type , # 0: softmax; 1: sigmoid
202
+ # out_flag=False, # todo new api; should the third output be output
203
+ # y2_flag=False, # old api; should the third output be output
204
+ routed_scaling_factor = 1 ,
205
+ eps = float ( 1e-20 ) )
206
+ if scoring_func == "softmax" :
208
207
topk_weights = _renormalize_topk_weights (topk_weights , renormalize )
209
208
210
- return topk_weights , topk_ids , row_idx
209
+ return topk_weights , topk_ids
211
210
212
211
213
212
def _native_select_experts (
@@ -281,3 +280,4 @@ def _native_select_experts(
281
280
topk_weights = _renormalize_topk_weights (topk_weights , renormalize )
282
281
283
282
return topk_weights , topk_ids
283
+
0 commit comments