From e1192e2e03e060fda1d1942c1d0391423e8ee519 Mon Sep 17 00:00:00 2001
From: realliujiaxu
Date: Thu, 14 Aug 2025 11:28:01 +0800
Subject: [PATCH] support static EPLB for allgather EP

Remap logical expert ids to their physical ids through the log2phy table
and include the redundant experts when deriving the global and per-rank
expert counts, so the allgather EP path works with static EPLB.

Signed-off-by: realliujiaxu
---
 vllm_ascend/quantization/w8a8_dynamic.py | 30 +++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 38aad66fa0..4add71618d 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -531,15 +531,21 @@ def fused_experts_with_all2all(
     return final_hidden_states
 
 
-def fused_experts_with_allgather(hidden_states: torch.Tensor,
-                                 w1: torch.Tensor,
-                                 w1_scale: torch.Tensor,
-                                 w2: torch.Tensor,
-                                 w2_scale: torch.Tensor,
-                                 topk_weights: torch.Tensor,
-                                 topk_ids: torch.Tensor,
-                                 top_k: int,
-                                 expert_map: torch.Tensor = None):
+def fused_experts_with_allgather(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+) -> torch.Tensor:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -551,7 +557,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
     ep_rank = torch.distributed.get_rank(group=ep_group)
     ep_size = torch.distributed.get_world_size(ep_group)
 
-    global_num_experts = len(expert_map)
+    global_num_experts = len(expert_map) + global_redundant_expert_num
     local_num_experts = global_num_experts // ep_size
 
     hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -961,7 +967,9 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,
-                expert_map=expert_map)
+                expert_map=expert_map,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num)
         elif fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
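
A quick illustration of the remapping this patch relies on: static EPLB
replicates hot experts into redundant physical slots and routes tokens
through a logical-to-physical table, which is what the added
`topk_ids = log2phy[topk_ids]` line does. The sketch below is illustrative
only; the `build_log2phy` helper is hypothetical and is not part of this
patch or of vllm-ascend's actual EPLB loader.

import torch

def build_log2phy(num_logical: int, redundant: list[int]) -> torch.Tensor:
    # Physical slots 0..num_logical-1 hold the original experts; each entry
    # in `redundant` appends one replica of that logical expert. This toy
    # placement simply redirects a duplicated expert to its replica, while
    # a real static EPLB plan would spread tokens across all replicas.
    log2phy = torch.arange(num_logical)
    for offset, logical_id in enumerate(redundant):
        log2phy[logical_id] = num_logical + offset
    return log2phy

# 4 logical experts, expert 1 duplicated once -> 5 physical slots.
log2phy = build_log2phy(num_logical=4, redundant=[1])
topk_ids = torch.tensor([[0, 1], [1, 3]])  # router output, logical ids
physical_ids = log2phy[topk_ids]           # same indexing as in the patch
print(physical_ids)                        # tensor([[0, 4], [4, 3]])

# Mirrors the second hunk: with replicas, the physical expert count is
# len(expert_map) + global_redundant_expert_num, and each EP rank's
# local_num_experts must be derived from that physical total.
global_num_experts = 4 + 1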