@@ -539,7 +539,12 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
                                 topk_weights: torch.Tensor,
                                 topk_ids: torch.Tensor,
                                 top_k: int,
-                                expert_map: torch.Tensor = None):
+                                expert_map: torch.Tensor = None,
+                                log2phy: torch.Tensor = None,
+                                global_redundant_expert_num: int = 0,
+                                ) -> torch.Tensor:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
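For orientation, here is a minimal runnable sketch of the logical-to-physical remapping this hunk introduces. The table contents and expert counts below are invented for illustration; a real EPLB placement table may be per-rank or 2D, but the indexing step is the same one the patch performs:

import torch

# Hypothetical setup: 4 logical experts, logical expert 1 has one
# redundant physical copy, so there are 5 physical expert slots.
log2phy = torch.tensor([0, 1, 2, 3])       # copy-free placement
log2phy_alt = torch.tensor([0, 4, 2, 3])   # routes expert 1 to its copy in slot 4

# topk_ids from the router are logical expert ids; indexing with them
# remaps to physical slots, as in the patched fused_experts_with_allgather.
topk_ids = torch.tensor([[1, 3], [0, 1]])
print(log2phy[topk_ids])      # tensor([[1, 3], [0, 1]])
print(log2phy_alt[topk_ids])  # tensor([[4, 3], [0, 4]])

Remapping at topk_ids keeps the router itself unchanged: it still scores logical experts, and the table alone decides which physical copy serves each token.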
@@ -551,7 +556,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
     ep_rank = torch.distributed.get_rank(group=ep_group)
     ep_size = torch.distributed.get_world_size(ep_group)
 
-    global_num_experts = len(expert_map)
+    global_num_experts = len(expert_map) + global_redundant_expert_num
     local_num_experts = global_num_experts // ep_size
 
     hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
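A quick worked example of the count arithmetic in this hunk, with made-up numbers:

# Illustrative numbers only: 64 logical experts, 8 redundant physical
# copies, expert parallelism across 8 ranks.
len_expert_map = 64
global_redundant_expert_num = 8
ep_size = 8

global_num_experts = len_expert_map + global_redundant_expert_num  # 72
local_num_experts = global_num_experts // ep_size                  # 9 per rank
assert global_num_experts % ep_size == 0, "physical experts must split evenly"

Without the fix, the redundant copies would be omitted from global_num_experts and each rank would compute too few local experts.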
@@ -961,7 +966,9 @@ def apply(
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
-               expert_map=expert_map)
+               expert_map=expert_map,
+               log2phy=log2phy,
+               global_redundant_expert_num=global_redundant_expert_num)
         elif fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
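Finally, a hedged sketch of how a log2phy table could be derived from a physical slot placement; build_log2phy and the placement list are hypothetical helpers for illustration, not part of this patch:

import torch

def build_log2phy(placement: list[int]) -> torch.Tensor:
    """placement[slot] = logical expert hosted in that physical slot.

    When a logical expert has several physical copies, this naive policy
    keeps the first slot found; a real balancer would pick per rank/load.
    """
    num_logical = max(placement) + 1
    log2phy = torch.full((num_logical,), -1, dtype=torch.long)
    for slot, logical in enumerate(placement):
        if log2phy[logical] == -1:
            log2phy[logical] = slot
    return log2phy

# 4 logical experts spread over 6 physical slots; experts 1 and 2
# each have a redundant copy (slots 4 and 5).
print(build_log2phy([0, 1, 2, 3, 1, 2]))  # tensor([0, 1, 2, 3])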