@@ -531,15 +531,21 @@ def fused_experts_with_all2all(
    return final_hidden_states


-def fused_experts_with_allgather(hidden_states: torch.Tensor,
-                                 w1: torch.Tensor,
-                                 w1_scale: torch.Tensor,
-                                 w2: torch.Tensor,
-                                 w2_scale: torch.Tensor,
-                                 topk_weights: torch.Tensor,
-                                 topk_ids: torch.Tensor,
-                                 top_k: int,
-                                 expert_map: torch.Tensor = None):
+def fused_experts_with_allgather(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w1_scale: torch.Tensor,
+        w2: torch.Tensor,
+        w2_scale: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        top_k: int,
+        expert_map: torch.Tensor = None,
+        log2phy: torch.Tensor = None,
+        global_redundant_expert_num: int = 0,
+) -> torch.Tensor:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
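A note on the new `log2phy` guard: with redundant experts the router still emits logical expert IDs, and `log2phy[topk_ids]` rewrites them to physical slot IDs via advanced indexing; when `log2phy` is `None` the IDs pass through untouched, so existing callers keep the old behavior. A minimal standalone sketch of the remap, with a made-up 1-D placement (the shapes and values are illustrative assumptions, not taken from this PR):

```python
import torch

# Hypothetical placement: 4 logical experts, and logical expert 1 is served
# by a redundant replica at physical slot 4 (1 redundant expert, so 5
# physical slots in total).
log2phy = torch.tensor([0, 4, 2, 3])

# Router output: logical expert IDs per token, shape [num_tokens, top_k].
topk_ids = torch.tensor([[1, 3],
                         [0, 1]])

# The remap from the diff: advanced indexing keeps the [num_tokens, top_k]
# shape while swapping each logical ID for its physical counterpart.
physical_ids = log2phy[topk_ids]
# tensor([[4, 3],
#         [0, 4]])
```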
@@ -551,7 +557,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
    ep_rank = torch.distributed.get_rank(group=ep_group)
    ep_size = torch.distributed.get_world_size(ep_group)

-    global_num_experts = len(expert_map)
+    global_num_experts = len(expert_map) + global_redundant_expert_num
    local_num_experts = global_num_experts // ep_size

    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
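The one-line change above is what makes the per-rank split come out right: the physical expert count is the logical count plus the redundant replicas, and that total, not the logical one, is what gets divided across EP ranks. A quick worked example with assumed numbers:

```python
# Assumed values for illustration only.
num_logical_experts = 256         # len(expert_map)
global_redundant_expert_num = 32  # extra physical replicas across all ranks
ep_size = 16                      # expert-parallel world size

global_num_experts = num_logical_experts + global_redundant_expert_num  # 288
local_num_experts = global_num_experts // ep_size                       # 18

# Without the fix, local_num_experts would be 256 // 16 == 16, under-counting
# the expert slots each rank actually hosts once redundant experts are placed.
```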
@@ -961,7 +967,9 @@ def apply(
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
-                expert_map=expert_map)
+                expert_map=expert_map,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num)
        elif fused_moe_state == FusedMoEState.MC2:
            return fused_experts_with_mc2(
                hidden_states=x,