@@ -141,7 +141,11 @@ def fused_experts_with_mc2(
     is_torchair: bool = False,
     hidden_states_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    log2phy: Optional[torch.Tensor] = None,
+    global_redundant_expert_num: int = 0
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     quant_mode = 0
     ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
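Note: the new log2phy lookup remaps each logical expert id produced by the router to a physical expert id, which is how redundant (duplicated) expert copies get addressed. A minimal sketch of that gather-style remap, using a made-up placement in which logical expert 2 has an extra physical replica (the table and shapes here are illustrative assumptions, not the PR's actual mapping):

import torch

# Hypothetical logical->physical table: 4 logical experts, logical expert 2 is
# replicated, so physical ids run 0..4 and this rank routes logical 2 to
# physical 3 (another rank's table could route it to the other replica).
log2phy = torch.tensor([0, 1, 3, 4])

topk_ids = torch.tensor([[2, 0], [3, 2]])  # router output, logical ids
topk_ids = log2phy[topk_ids]               # advanced indexing does the remap
# topk_ids is now tensor([[3, 0], [4, 3]]) -- physical ids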
@@ -163,7 +167,7 @@ def fused_experts_with_mc2(
 
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
 
-    moe_expert_num = len(expert_map)
+    moe_expert_num = len(expert_map) + global_redundant_expert_num
     kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
@@ -349,17 +353,16 @@ def apply_mlp(
 
 # currently expert parallelism implemented with all2all
 # is under-optimized.
-def fused_experts_with_all2all(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    ep_group: GroupCoordinator = None,
-    max_num_tokens: Optional[int] = None,
-):
+def fused_experts_with_all2all(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w2: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               top_k: int,
+                               expert_map: torch.Tensor = None,
+                               ep_group: GroupCoordinator = None,
+                               max_num_tokens: Optional[int] = None,
+                               global_redundant_expert_num: int = 0):
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -369,7 +372,7 @@ def fused_experts_with_all2all(
     device = hidden_states.device
 
     if expert_map is not None:
-        global_num_experts = len(expert_map)
+        global_num_experts = len(expert_map) + global_redundant_expert_num
         local_num_experts = global_num_experts // ep_group.world_size
         row_idx_len = num_tokens * top_k
         row_idx = (torch.arange(0,
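With redundancy enabled, the dispatch logic sizes the global expert space as the logical expert count plus the number of extra physical copies, and each EP rank then owns an equal slice of that space. A quick worked example (the concrete numbers below are only an illustration):

# Illustrative numbers, not taken from any real deployment.
num_logical_experts = 64           # plays the role of len(expert_map)
global_redundant_expert_num = 8    # extra physical replicas across the cluster
ep_world_size = 8

global_num_experts = num_logical_experts + global_redundant_expert_num  # 72
local_num_experts = global_num_experts // ep_world_size                 # 9 per rank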
@@ -639,7 +642,10 @@ def fused_experts_with_all2allv(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    log2phy: Optional[torch.Tensor] = None,
 ):
+    if log2phy is not None:
+        routing_map = log2phy[routing_map]
     # Enable moe alltoallv, it's a balanced policy for precision and efficiency.
     (share_experts_output, dispatched_input,
      tokens_per_expert) = (token_dispatcher.token_permutation(
@@ -824,8 +830,8 @@ def fused_experts(
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids,
     )
-
-    return final_hidden_states
+    group_list_type = 0
+    return final_hidden_states, expert_tokens, group_list_type
 
 
 def native_grouped_topk(
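This hunk changes the return contract of fused_experts on this path: callers now receive the per-expert token tensor and a group list type flag alongside the output instead of a bare tensor. A toy stub that only mimics the new arity, to show the caller-side unpacking (everything below is a placeholder, not the real kernel):

import torch

def fused_experts_stub():
    # Mimics only the (output, expert_tokens, group_list_type) return shape
    # introduced above; the values are placeholders.
    final_hidden_states = torch.zeros(4, 8)
    expert_tokens = torch.tensor([2, 0, 1, 1])
    group_list_type = 0
    return final_hidden_states, expert_tokens, group_list_type

final_hidden_states, expert_tokens, group_list_type = fused_experts_stub()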
@@ -1015,6 +1021,8 @@ def apply(
         enable_force_load_balance: bool = False,
         hidden_states_for_share: Optional[Any] = None,
         shared_experts: Optional[Any] = None,
+        log2phy: Optional[Any] = None,
+        global_redundant_expert_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -1071,6 +1079,8 @@ def apply(
                 is_torchair=self.torchair_graph_enabled,
                 hidden_states_for_share=hidden_states_for_share,
                 mc2_mask=mc2_mask,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
             )
         elif fused_moe_state == FusedMoEState.AllGather:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
@@ -1105,18 +1115,20 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-            )
+                log2phy=log2phy)
         else:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group(),
-                                              max_num_tokens=max_num_tokens)
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group(),
+                max_num_tokens=max_num_tokens,
+                global_redundant_expert_num=global_redundant_expert_num)
 
 
 class AscendFusedMoE(FusedMoE):
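The apply hook itself only threads log2phy and global_redundant_expert_num through to the MC2 and all2all kernels. As a sketch of how a caller might derive those two values from an expert placement that contains duplicated experts (the placement format and helper below are assumptions for illustration, not this module's actual expert-map loader):

import torch

def build_log2phy(placement: list[list[int]]) -> tuple[torch.Tensor, int]:
    """placement[logical_id] lists the physical expert ids hosting that expert."""
    # Pick the first replica of every logical expert for this sketch; a real
    # balancer would choose per rank (or per token) to spread load.
    log2phy = torch.tensor([replicas[0] for replicas in placement])
    num_physical = sum(len(replicas) for replicas in placement)
    return log2phy, num_physical - len(placement)

# Logical expert 1 has two physical copies (ids 1 and 3).
log2phy, global_redundant_expert_num = build_log2phy([[0], [1, 3], [2]])
# log2phy -> tensor([0, 1, 2]); global_redundant_expert_num -> 1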
@@ -1273,6 +1285,10 @@ def __init__(
         if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
                 self.quant_method, AscendUnquantizedFusedMoEMethod):
             self.reduce_results = False
+        if expert_map_path and os.path.exists(expert_map_path):
+            self.global_num_experts = self.global_num_experts + self.global_redundant_expert_num
+            self.local_num_experts = self.global_num_experts // self.ep_size
+
         moe_dispatcher_config = (
             MoEDispatcherConfig().set_num_moe_experts(
                 self.global_num_experts).set_num_local_experts(