Skip to content

Commit e1192e2

Browse files
committed
support static EPLB for allgather EP
Signed-off-by: realliujiaxu <[email protected]>
1 parent 103654c commit e1192e2

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -531,15 +531,21 @@ def fused_experts_with_all2all(
531531
return final_hidden_states
532532

533533

534-
def fused_experts_with_allgather(hidden_states: torch.Tensor,
535-
w1: torch.Tensor,
536-
w1_scale: torch.Tensor,
537-
w2: torch.Tensor,
538-
w2_scale: torch.Tensor,
539-
topk_weights: torch.Tensor,
540-
topk_ids: torch.Tensor,
541-
top_k: int,
542-
expert_map: torch.Tensor = None):
534+
def fused_experts_with_allgather(
535+
hidden_states: torch.Tensor,
536+
w1: torch.Tensor,
537+
w1_scale: torch.Tensor,
538+
w2: torch.Tensor,
539+
w2_scale: torch.Tensor,
540+
topk_weights: torch.Tensor,
541+
topk_ids: torch.Tensor,
542+
top_k: int,
543+
expert_map: torch.Tensor = None,
544+
log2phy: torch.Tensor = None,
545+
global_redundant_expert_num: int = 0,
546+
) -> torch.Tensor:
547+
if log2phy is not None:
548+
topk_ids = log2phy[topk_ids]
543549
original_shape = hidden_states.shape
544550
if len(original_shape) == 3:
545551
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -551,7 +557,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
551557
ep_rank = torch.distributed.get_rank(group=ep_group)
552558
ep_size = torch.distributed.get_world_size(ep_group)
553559

554-
global_num_experts = len(expert_map)
560+
global_num_experts = len(expert_map) + global_redundant_expert_num
555561
local_num_experts = global_num_experts // ep_size
556562

557563
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -961,7 +967,9 @@ def apply(
961967
topk_weights=topk_weights,
962968
topk_ids=topk_ids,
963969
top_k=top_k,
964-
expert_map=expert_map)
970+
expert_map=expert_map,
971+
log2phy=log2phy,
972+
global_redundant_expert_num=global_redundant_expert_num)
965973
elif fused_moe_state == FusedMoEState.MC2:
966974
return fused_experts_with_mc2(
967975
hidden_states=x,

0 commit comments

Comments
 (0)