Skip to content

Commit 688bd57

Browse files
1092626063Liccol
authored and committed
refactor gatingtopk
1 parent c73dd8f commit 688bd57

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 42 additions & 42 deletions
Original file line number · Diff line number · Diff line change
@@ -66,21 +66,21 @@ def select_experts(hidden_states: torch.Tensor,
6666
topk_ids: selected expert IDs of shape (num_tokens, top_k).
6767
"""
6868

69-
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
70-
hidden_states=hidden_states,
71-
router_logits=router_logits,
72-
top_k=top_k,
73-
use_grouped_topk=use_grouped_topk,
74-
topk_group=topk_group,
75-
renormalize=renormalize,
76-
e_score_correction_bias=e_score_correction_bias,
77-
num_expert_group=num_expert_group,
78-
custom_routing_function=custom_routing_function,
79-
scoring_func=scoring_func,
80-
routed_scaling_factor=routed_scaling_factor,
81-
global_num_experts=global_num_experts)
82-
83-
if topk_weights is None:
69+
topk_weights, topk_ids, row_idx = None, None, None
70+
if custom_routing_function is None:
71+
topk_weights, topk_ids = _select_experts_with_fusion_ops(
72+
hidden_states=hidden_states,
73+
router_logits=router_logits,
74+
top_k=top_k,
75+
use_grouped_topk=use_grouped_topk,
76+
topk_group=topk_group,
77+
renormalize=renormalize,
78+
e_score_correction_bias=e_score_correction_bias,
79+
num_expert_group=num_expert_group,
80+
scoring_func=scoring_func,
81+
routed_scaling_factor=routed_scaling_factor,
82+
global_num_experts=global_num_experts)
83+
else:
8484
topk_weights, topk_ids = _native_select_experts(
8585
hidden_states=hidden_states,
8686
router_logits=router_logits,
@@ -94,8 +94,8 @@ def select_experts(hidden_states: torch.Tensor,
9494
e_score_correction_bias=e_score_correction_bias,
9595
global_num_experts=global_num_experts,
9696
)
97-
if row_idx is None:
98-
row_idx = return_row_idx(hidden_states, top_k)
97+
if row_idx is None:
98+
row_idx = return_row_idx(hidden_states, top_k)
9999
return topk_weights, topk_ids, row_idx
100100

101101

@@ -177,37 +177,36 @@ def _select_experts_with_fusion_ops(
177177
e_score_correction_bias: Optional[torch.Tensor],
178178
topk_group: Optional[int],
179179
num_expert_group: Optional[int],
180-
custom_routing_function: Optional[Callable] = None,
181180
scoring_func: str = "softmax",
182181
routed_scaling_factor=1.0,
183182
global_num_experts: int = -1):
184183

185-
topk_weights, topk_ids, row_idx = None, None, None
186-
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
187-
is_deepseek_v3_r1 = global_num_experts == 256
188-
if is_deepseek_v3_r1:
189-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
190-
router_logits,
191-
k=top_k, # topk currently 8
192-
bias=e_score_correction_bias,
193-
k_group=topk_group, # fix: 4
194-
group_count=num_expert_group, # fix 8
195-
group_select_mode=
196-
1, # 0: the maximum in the group; 1: topk2.sum(fix)
197-
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
198-
norm_type=1, # 0: softmax; 1: sigmoid(fix)
199-
# out_flag=False, # todo new api; should the third output be output
200-
# y2_flag=False, # old api; should the third output be output
201-
routed_scaling_factor=1,
202-
eps=float(1e-20))
203-
row_idx = return_row_idx(hidden_states, top_k)
204-
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
205-
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
206-
x=router_logits, finished=None, k=top_k)
207-
topk_ids = topk_ids.to(torch.int32)
184+
if scoring_func == "softmax":
185+
norm_type = 0
186+
topk_group = 1
187+
num_expert_group = 1
188+
else:
189+
norm_type = 1
190+
if e_score_correction_bias is not None and \
191+
e_score_correction_bias.dtype != router_logits.dtype:
192+
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
193+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
194+
router_logits,
195+
k=top_k,
196+
bias=e_score_correction_bias,
197+
k_group=topk_group,
198+
group_count=num_expert_group,
199+
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
200+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
201+
norm_type=norm_type, # 0: softmax; 1: sigmoid
202+
# out_flag=False, # todo new api; should the third output be output
203+
# y2_flag=False, # old api; should the third output be output
204+
routed_scaling_factor=1,
205+
eps=float(1e-20))
206+
if scoring_func == "softmax":
208207
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
209208

210-
return topk_weights, topk_ids, row_idx
209+
return topk_weights, topk_ids
211210

212211

213212
def _native_select_experts(
@@ -281,3 +280,4 @@ def _native_select_experts(
281280
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
282281

283282
return topk_weights, topk_ids
283+

0 commit comments

Comments
 (0)