Commit ae55649

support static EPLB for allgather EP
Signed-off-by: realliujiaxu <[email protected]>
1 parent 103654c commit ae55649

File tree

1 file changed: +10 additions, -3 deletions


vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 10 additions & 3 deletions
@@ -539,7 +539,12 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
                                  topk_weights: torch.Tensor,
                                  topk_ids: torch.Tensor,
                                  top_k: int,
-                                 expert_map: torch.Tensor = None):
+                                 expert_map: torch.Tensor = None,
+                                 log2phy: torch.Tensor = None,
+                                 global_redundant_expert_num: int = 0,
+                                 ) -> torch.Tensor:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -551,7 +556,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
     ep_rank = torch.distributed.get_rank(group=ep_group)
     ep_size = torch.distributed.get_world_size(ep_group)
 
-    global_num_experts = len(expert_map)
+    global_num_experts = len(expert_map) + global_redundant_expert_num
     local_num_experts = global_num_experts // ep_size
 
     hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -961,7 +966,9 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,
-                expert_map=expert_map)
+                expert_map=expert_map,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num)
         elif fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
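
For context, the change makes the allgather expert-parallel path aware of static EPLB (expert parallel load balancing): when redundant physical copies of experts exist, the router's logical top-k expert ids are remapped to physical expert slots through log2phy, and the physical expert count grows by global_redundant_expert_num. The sketch below replays those two lines in isolation; the expert counts, the placement encoded in log2phy, and the ep_size value are illustrative assumptions, not the actual vllm-ascend configuration.

import torch

# Hypothetical static-EPLB setup: 8 logical experts plus 2 redundant
# physical replicas, so 10 physical expert slots in total.
num_logical_experts = 8
global_redundant_expert_num = 2

# Assumed placement for this rank: hot logical experts 0 and 1 are served
# from their replica slots 8 and 9; the rest map to themselves.
log2phy = torch.tensor([8, 9, 2, 3, 4, 5, 6, 7])

# Router output: top-2 logical expert ids for 3 tokens.
topk_ids = torch.tensor([[6, 1],
                         [7, 0],
                         [6, 7]])

# Same remap the patched fused_experts_with_allgather performs.
topk_ids = log2phy[topk_ids]
print(topk_ids)  # tensor([[6, 9], [7, 8], [6, 7]])

# The per-rank split must count the redundant copies as well,
# mirroring the second hunk above.
expert_map = torch.arange(num_logical_experts)  # stand-in for the real map
ep_size = 2
global_num_experts = len(expert_map) + global_redundant_expert_num  # 10
local_num_experts = global_num_experts // ep_size                   # 5

Since log2phy defaults to None and global_redundant_expert_num to 0, callers that do not enable EPLB keep the previous behaviour unchanged.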

0 commit comments
