
Commit 5d8ec28

[2/N][refactor] split torchair from fused_moe (#2503)
### What this PR does / why we need it?
After moving the torchair-related fused_moe code into torchair_fused_moe, split the torchair logic out of the original fused_moe.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
vLLM version: main
vLLM main: vllm-project/vllm@ab9f2cf

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@2a97ffc

Signed-off-by: hust17yixuan <[email protected]>
1 parent cfe77e8 commit 5d8ec28
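The commit message describes moving the TorchAir-specific MoE code out of the shared `vllm_ascend/ops/fused_moe.py`. For orientation, here is a minimal sketch of that separation; only the file above and the name `torchair_fused_moe` come from the commit, while the paths and the helper below are illustrative assumptions, not the actual vllm-ascend layout.

```python
# Illustrative sketch only: the torchair-side module path and this helper are assumptions,
# not the real vllm-ascend layout or API.
EAGER_MOE_MODULE = "vllm_ascend.ops.fused_moe"  # keeps only the generic, eager-mode path
TORCHAIR_MOE_MODULE = "vllm_ascend.torchair.ops.torchair_fused_moe"  # assumed new home

def moe_module_for(torchair_graph_enabled: bool) -> str:
    """Hypothetical helper: pick the module that owns the MoE implementation for a mode."""
    return TORCHAIR_MOE_MODULE if torchair_graph_enabled else EAGER_MOE_MODULE
```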

File tree

1 file changed: +16 −47 lines


vllm_ascend/ops/fused_moe.py

Lines changed: 16 additions & 47 deletions
@@ -50,7 +50,6 @@
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
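The removed import (`npu_stream_switch`, `npu_wait_tensor`) provides secondary-stream overlap with tensor-level ordering and is now used only by the torchair path. As a rough analogy only (the CUDA stream API stands in for the Ascend utilities, which are not reproduced here), the pattern looks like this:

```python
# Rough analogy, not the Ascend implementation: a side stream that waits on the producer
# stream before running extra work, then is awaited before its output is consumed.
import torch

def run_on_side_stream(fn, *tensors):
    """Run fn(*tensors) on a secondary stream once its inputs are ready."""
    if not torch.cuda.is_available():
        return fn(*tensors)  # nothing to overlap; run in order
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())   # roughly what npu_wait_tensor achieves
    with torch.cuda.stream(side):                   # roughly npu_stream_switch("moe_secondary", 0)
        out = fn(*tensors)
    torch.cuda.current_stream().wait_stream(side)   # order the result before later consumers
    return out
```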
@@ -76,8 +75,6 @@ def unified_fused_experts(
     w1_scale_bias: torch.Tensor = None,
     w2_scale_bias: torch.Tensor = None,
     moe_comm_method: Optional[MoECommMethod] = None,
-    # For TorchAir graph
-    is_torchair: bool = False,
     # For Cube/Vector parallel
     shared_experts: Optional[Any] = None,
     quantized_x_for_share: Optional[Any] = None,
@@ -191,16 +188,14 @@ def fused_experts_with_mc2(
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
-    is_torchair: bool = False,
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_rank_id = moe_parallel_config.ep_rank
     ep_world_size = moe_parallel_config.ep_size
 
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
-                       or is_torchair)
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)
 
     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
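With `is_torchair` gone, both flags above reduce to a check of the SoC generation. A minimal sketch, using a local stand-in enum (the real `AscendSocVersion` and `get_ascend_soc_version` live in `vllm_ascend.utils`, see the import hunk earlier):

```python
# Minimal sketch of the simplified gating; this enum is a stand-in, not the real one.
from enum import Enum, auto

class AscendSocVersion(Enum):
    A2 = auto()  # illustrative member
    A3 = auto()

def needs_extra_mc2_args(soc: AscendSocVersion) -> bool:
    # On the eager path the extra dispatch/combine arguments now depend only on the SoC,
    # not on whether a torchair graph is running.
    return soc == AscendSocVersion.A3
```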
@@ -246,11 +241,8 @@ def fused_experts_with_mc2(
         0:5]
 
     if shared_experts is not None:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+        shared_act = shared_experts.act_fn(shared_gate_up)
 
     w1 = w1.transpose(1, 2)
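With the secondary-stream wrapper removed, the shared expert on this path is an ordinary MLP evaluated in program order: `gate_up_proj`, then `act_fn`, with `down_proj` applied later (see the next hunk). A self-contained sketch of that structure; only those three names come from the diff, while the sizes and the SiLU-and-multiply activation are assumptions:

```python
# Stand-alone sketch of the shared-expert computation as it now runs on the eager path.
# Projection sizes and the SiLU-and-multiply activation are illustrative assumptions.
import torch
import torch.nn as nn

class SharedExpertSketch(nn.Module):
    def __init__(self, hidden_size: int = 1024, intermediate_size: int = 2048):
        super().__init__()
        # Fused gate+up projection: one matmul producing 2 * intermediate_size features.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def act_fn(self, gate_up: torch.Tensor) -> torch.Tensor:
        gate, up = gate_up.chunk(2, dim=-1)
        return nn.functional.silu(gate) * up

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        shared_gate_up = self.gate_up_proj(hidden_states)   # no stream switch, no waits
        shared_act = self.act_fn(shared_gate_up)
        return self.down_proj(shared_act)

# Example: SharedExpertSketch()(torch.randn(4, 1024)).shape -> torch.Size([4, 1024])
```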

@@ -324,9 +316,7 @@ def fused_experts_with_mc2(
     if shared_experts is None:
         return hidden_states
     else:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act, down_out_list)
-            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
         return hidden_states, shared_hidden_states
 

@@ -930,9 +920,7 @@ def __init__(self, moe: FusedMoEConfig = None):
 
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.max_model_len = vllm_config.model_config.max_model_len
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        get_ascend_config()
 
         try:
             device_group = get_mc2_group().device_group
@@ -1169,10 +1157,6 @@ def __init__(
             self.ep_size,
             get_ep_group().rank_in_group, self.global_num_experts)
 
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
-            self.torchair_graph_enabled
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1278,23 +1262,10 @@ def forward(self,
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
-        if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
 
         if shared_experts:
-            if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
-                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-                shared_hidden_states = shared_experts(hidden_states)
+            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+            shared_hidden_states = shared_experts(hidden_states)
 
         mc2_mask = forward_context.mc2_mask
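The deleted block pre-quantized `hidden_states` on a secondary stream for the W8A8-dynamic MC2 multistream case, which now belongs to the torchair module. A hedged sketch of the kind of result that produces, assuming `torch_npu.npu_dynamic_quant` performs per-token symmetric int8 quantization (the NPU op itself is not reproduced here):

```python
# Hedged sketch in plain torch: per-token symmetric int8 quantization, i.e. the shape of
# result the removed npu_dynamic_quant call supplied to the shared-expert path.
import torch

def dynamic_quant_per_token(x: torch.Tensor):
    """Return (int8 tensor, per-token scale) for a [num_tokens, hidden_size] input."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale.squeeze(-1)

# Example:
# quantized_x_for_share, dynamic_scale_for_share = dynamic_quant_per_token(torch.randn(4, 1024))
```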

@@ -1339,16 +1310,15 @@ def forward(self,
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
-                if not self.torchair_graph_enabled:
-                    max_tokens_across_dp = forward_context.max_tokens_across_dp
-                    if num_tokens < max_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
+                max_tokens_across_dp = forward_context.max_tokens_across_dp
+                if num_tokens < max_tokens_across_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states,
+                        (0, 0, 0, max_tokens_across_dp - num_tokens))
+                    if not self.rm_router_logits:
+                        router_logits = nn.functional.pad(
+                            router_logits,
                             (0, 0, 0, max_tokens_across_dp - num_tokens))
-                        if not self.rm_router_logits:
-                            router_logits = nn.functional.pad(
-                                router_logits,
-                                (0, 0, 0, max_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
                 if self.rm_router_logits:
                     router_logits, _ = gate(hidden_states)
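The padding above brings every data-parallel rank to the same token count before the all-gather. A self-contained illustration of the `nn.functional.pad` call it relies on (the distributed all-gather itself is omitted, and `max_tokens_across_dp` is just an assumed value here): the pad tuple `(0, 0, 0, k)` leaves the last (hidden) dimension untouched and appends `k` zero rows to the token dimension.

```python
# Self-contained illustration of the padding used before the DP all-gather.
import torch
import torch.nn as nn

num_tokens, hidden_size = 3, 8
max_tokens_across_dp = 5  # largest per-rank token count in the DP group (assumed value)

hidden_states = torch.randn(num_tokens, hidden_size)
if num_tokens < max_tokens_across_dp:
    # (0, 0, 0, k): no padding on the last dim, k extra zero rows on the token dim.
    hidden_states = nn.functional.pad(
        hidden_states, (0, 0, 0, max_tokens_across_dp - num_tokens))

print(hidden_states.shape)  # torch.Size([5, 8]) -- every rank now gathers equal-sized tensors
```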
@@ -1385,8 +1355,7 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            shared_experts=shared_experts if self.torchair_graph_enabled
-            and self.enable_multistream_moe and not is_prefill else None,
+            shared_experts=None,
             mc2_mask=mc2_mask,
             token_dispatcher=self.token_dispatcher,
             quantized_x_for_share=quantized_x_for_share,
