replace arange and permute with the third output of npu_moe_gating_top_k_softmax #2418
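For context, the `row_idx` tensor this PR threads through from the gating op is the same token-major index layout that `fused_experts` previously rebuilt with `arange`/`permute` on every call. A minimal, CPU-runnable sketch of that construction (toy sizes; names taken from the diff below):

```python
import torch

# Toy sizes for illustration; in the diff these come from the routed batch.
num_tokens, top_k = 4, 2

row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32)
           .view(top_k, -1)
           .permute(1, 0)
           .contiguous())
print(row_idx)
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)
```

Per this PR, `torch_npu.npu_moe_gating_top_k_softmax` already emits an equivalent tensor as its third output, so recomputing it in every layer is redundant.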
```diff
@@ -759,6 +759,7 @@ def fused_experts(
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,
     max_num_tokens: Optional[int] = None,
+    row_idx: torch.Tensor = None,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
```
```diff
@@ -846,12 +847,13 @@ def fused_experts(
         # Rearrange hidden_states
         sorted_hidden_states = hidden_states[sorted_token_indices]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
+        if row_idx is None:
+            row_idx_len = num_tokens * top_k
+            row_idx = (torch.arange(0,
+                                    row_idx_len,
+                                    dtype=torch.int32,
+                                    device=device).view(top_k, -1).permute(
+                                        1, 0).contiguous())
         active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
```

Review comment: Seems all layers use the same
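The `if row_idx is None` guard keeps the old behavior for callers that do not pass a precomputed index. A minimal sketch of that fallback pattern (a plain-torch stand-in, not the repo's code):

```python
from typing import Optional

import torch


def _row_idx_or_fallback(num_tokens: int, top_k: int,
                         row_idx: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Stand-in for the guard above: build row_idx only when absent."""
    if row_idx is None:
        row_idx = (torch.arange(0, num_tokens * top_k, dtype=torch.int32)
                   .view(top_k, -1).permute(1, 0).contiguous())
    return row_idx


precomputed = _row_idx_or_fallback(4, 2)
# Passing the precomputed tensor skips the arange/permute entirely.
assert torch.equal(_row_idx_or_fallback(4, 2, row_idx=precomputed), precomputed)
```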
```diff
@@ -975,7 +977,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
```
```diff
@@ -1020,7 +1022,8 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,
-                expert_map=expert_map)
+                expert_map=expert_map,
+                row_idx=row_idx)
         elif MOE_ALL2ALL_BUFFER:
             return fused_experts_with_all2all_buffer(
                 hidden_states=x,
```
```diff
@@ -56,7 +56,7 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """

-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
         router_logits=router_logits,
         top_k=top_k,
         use_grouped_topk=use_grouped_topk,
```
```diff
@@ -83,7 +83,7 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx


 def _native_grouped_topk(
```
```diff
@@ -168,7 +168,7 @@ def _select_experts_with_fusion_ops(
         global_num_experts: int = -1,
         is_unquantized: bool = False):

-    topk_weights, topk_ids = None, None
+    topk_weights, topk_ids, row_idx = None, None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
```
```diff
@@ -186,14 +186,13 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+    elif not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
+        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx


 def _native_select_experts(
```

Review comment: This

Reply: fixed
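For reference, the fused gating call now consumes all three outputs. A hedged sketch of the call shape (requires an Ascend device with `torch_npu`; sizes are toy values):

```python
import torch
import torch_npu  # only available on Ascend NPU hosts

router_logits = torch.randn(4, 8, dtype=torch.float16).npu()  # (num_tokens, num_experts)
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
    x=router_logits, finished=None, k=2)
# Per this PR, row_idx matches the arange/view/permute construction that
# fused_experts previously rebuilt, so it can be passed straight through.
```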
Review comment: Could add a test case to cover the `fused_experts` function with the `row_idx` parameter passed in.
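A minimal sketch of such a test. The import path, the keyword signature of `fused_experts`, the w1/w2 weight shapes, and NPU availability are all assumptions here, not taken from the diff:

```python
import torch

# Assumed import path; adjust to the actual module layout.
from vllm_ascend.ops.fused_moe import fused_experts


def test_fused_experts_with_row_idx():
    device, dtype = "npu", torch.float16
    num_tokens, hidden_size, inter_size, num_experts, top_k = 8, 32, 64, 4, 2

    hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
    w1 = torch.randn(num_experts, 2 * inter_size, hidden_size, device=device, dtype=dtype)
    w2 = torch.randn(num_experts, hidden_size, inter_size, device=device, dtype=dtype)
    topk_weights = torch.rand(num_tokens, top_k, device=device, dtype=dtype)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k),
                             device=device, dtype=torch.int32)
    row_idx = (torch.arange(num_tokens * top_k, dtype=torch.int32, device=device)
               .view(top_k, -1).permute(1, 0).contiguous())

    # The same routing with and without the precomputed row_idx should agree,
    # since the None path rebuilds an identical tensor internally.
    out_with = fused_experts(hidden_states=hidden_states, w1=w1, w2=w2,
                             topk_weights=topk_weights, topk_ids=topk_ids,
                             top_k=top_k, row_idx=row_idx)
    out_without = fused_experts(hidden_states=hidden_states, w1=w1, w2=w2,
                                topk_weights=topk_weights, topk_ids=topk_ids,
                                top_k=top_k)
    assert torch.allclose(out_with, out_without, rtol=1e-3, atol=1e-3)
```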