from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                               get_all_reduce_merge_state,
                               get_ascend_soc_version,
@@ -76,8 +75,6 @@ def unified_fused_experts(
    w1_scale_bias: torch.Tensor = None,
    w2_scale_bias: torch.Tensor = None,
    moe_comm_method: Optional[MoECommMethod] = None,
-    # For TorchAir graph
-    is_torchair: bool = False,
    # For Cube/Vector parallel
    shared_experts: Optional[Any] = None,
    quantized_x_for_share: Optional[Any] = None,
@@ -191,16 +188,14 @@ def fused_experts_with_mc2(
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: Optional[str] = None,
    shared_experts: Optional[Any] = None,
-    is_torchair: bool = False,
    mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    quant_mode = 0
    ep_rank_id = moe_parallel_config.ep_rank
    ep_world_size = moe_parallel_config.ep_size

    # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
-                       or is_torchair)
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)

    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
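
With `is_torchair` dropped, the extra dispatch/combine arguments are gated purely on the SoC version. A minimal sketch of that gating; the keyword names are illustrative placeholders, not the real dispatch/combine kernel signatures:

```python
# Sketch only: how the extra dispatch arguments are now gated.
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


def build_dispatch_kwargs(base_kwargs: dict, ep_world_size: int,
                          ep_rank_id: int) -> dict:
    kwargs = dict(base_kwargs)
    # Extra parameters are needed only on A3 SoCs; the old
    # `or is_torchair` branch is gone along with the TorchAir path.
    need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
    if need_extra_args:
        kwargs.update(
            ep_world_size=ep_world_size,  # placeholder kwarg name
            ep_rank_id=ep_rank_id,  # placeholder kwarg name
        )
    return kwargs
```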
@@ -246,11 +241,8 @@ def fused_experts_with_mc2(
        0:5]

    if shared_experts is not None:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+        shared_act = shared_experts.act_fn(shared_gate_up)

    w1 = w1.transpose(1, 2)
@@ -324,9 +316,7 @@ def fused_experts_with_mc2(
    if shared_experts is None:
        return hidden_states
    else:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act, down_out_list)
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
        return hidden_states, shared_hidden_states
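
The two hunks above drop the secondary-stream overlap: previously the shared expert's gate_up_proj/act_fn and down_proj ran under npu_stream_switch, ordered against the MC2 outputs via npu_wait_tensor; now they run inline on the default stream. A rough sketch of the pattern being removed, written with the CUDA-style stream API for illustration (torch_npu exposes an analogous interface):

```python
# Rough sketch of the overlap implemented by the removed
# npu_stream_switch / npu_wait_tensor calls.
import torch


def shared_expert_on_side_stream(shared_experts, hidden_states):
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())  # ~ npu_wait_tensor
    with torch.cuda.stream(side_stream):
        # Shared-expert MLP overlaps with the routed-expert dispatch that
        # keeps running on the default stream.
        gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        act = shared_experts.act_fn(gate_up)
    # Re-synchronize before the default stream consumes `act`.
    torch.cuda.current_stream().wait_stream(side_stream)
    return act
# After this PR the same two calls simply run inline on the default stream.
```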
@@ -930,9 +920,7 @@ def __init__(self, moe: FusedMoEConfig = None):

        self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        get_ascend_config()

        try:
            device_group = get_mc2_group().device_group
@@ -1169,10 +1157,6 @@ def __init__(
            self.ep_size,
            get_ep_group().rank_in_group, self.global_num_experts)

-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
-            self.torchair_graph_enabled
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1278,23 +1262,10 @@ def forward(self,
        mc2_mask = forward_context.mc2_mask
        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
        quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
-        if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)

        if shared_experts:
-            if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
-                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-                shared_hidden_states = shared_experts(hidden_states)
+            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+            shared_hidden_states = shared_experts(hidden_states)

        mc2_mask = forward_context.mc2_mask
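
The retained comment refers to all_reduce_merge: the shared-expert MLP skips its own tensor-parallel all-reduce, and a single collective runs once the shared and routed outputs are combined. A minimal sketch of that idea, with illustrative names only (not the vllm_ascend API):

```python
# Sketch of the "all_reduce_merge" flow referenced in the comment above:
# sum the two partial results first, then issue one all-reduce instead of
# one per branch.
import torch
import torch.distributed as dist


def merged_moe_reduce(shared_out: torch.Tensor, routed_out: torch.Tensor,
                      tp_group: dist.ProcessGroup) -> torch.Tensor:
    combined = shared_out + routed_out  # both hold tensor-parallel partial sums
    dist.all_reduce(combined, group=tp_group)  # one collective instead of two
    return combined
```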
@@ -1339,16 +1310,15 @@ def forward(self,
        if self.dp_size > 1:
            if fused_moe_state == FusedMoEState.AllGather:
                # NOTE: When in torchair graph, it has been padded in model_runner_v1
-                if not self.torchair_graph_enabled:
-                    max_tokens_across_dp = forward_context.max_tokens_across_dp
-                    if num_tokens < max_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
+                max_tokens_across_dp = forward_context.max_tokens_across_dp
+                if num_tokens < max_tokens_across_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states,
+                        (0, 0, 0, max_tokens_across_dp - num_tokens))
+                    if not self.rm_router_logits:
+                        router_logits = nn.functional.pad(
+                            router_logits,
                            (0, 0, 0, max_tokens_across_dp - num_tokens))
-                        if not self.rm_router_logits:
-                            router_logits = nn.functional.pad(
-                                router_logits,
-                                (0, 0, 0, max_tokens_across_dp - num_tokens))
                hidden_states = get_dp_group().all_gather(hidden_states, 0)
                if self.rm_router_logits:
                    router_logits, _ = gate(hidden_states)
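
For reference, the padding above relies on F.pad's reversed-dimension convention: a spec of (0, 0, 0, n) leaves the hidden dimension untouched and appends n zero rows to the token dimension, so every DP rank contributes an equally sized tensor to the all-gather. A small self-contained example:

```python
# F.pad's spec is read from the last dimension backwards: (0, 0, 0, n) keeps
# the hidden dimension as-is and appends n zero rows to the token dimension.
import torch
import torch.nn.functional as F

num_tokens, hidden_size = 3, 8
max_tokens_across_dp = 5  # largest per-rank token count in the DP group
hidden_states = torch.randn(num_tokens, hidden_size)

padded = F.pad(hidden_states, (0, 0, 0, max_tokens_across_dp - num_tokens))
print(padded.shape)  # torch.Size([5, 8])
# Every DP rank now contributes the same shape, so all_gather along dim 0
# yields a [dp_size * max_tokens_across_dp, hidden_size] tensor.
```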
@@ -1385,8 +1355,7 @@ def forward(self,
                enable_force_load_balance=enable_force_load_balance,
                log2phy=self.log2phy,
                global_redundant_expert_num=self.global_redundant_expert_num,
-                shared_experts=shared_experts if self.torchair_graph_enabled
-                and self.enable_multistream_moe and not is_prefill else None,
+                shared_experts=None,
                mc2_mask=mc2_mask,
                token_dispatcher=self.token_dispatcher,
                quantized_x_for_share=quantized_x_for_share,