Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,15 +531,21 @@ def fused_experts_with_all2all(
return final_hidden_states


def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
def fused_experts_with_allgather(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
if log2phy is not None:
topk_ids = log2phy[topk_ids]
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
Expand All @@ -551,7 +557,7 @@ def fused_experts_with_allgather(hidden_states: torch.Tensor,
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)

global_num_experts = len(expert_map)
global_num_experts = len(expert_map) + global_redundant_expert_num
local_num_experts = global_num_experts // ep_size

hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
Expand Down Expand Up @@ -961,7 +967,9 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num)
elif fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
Expand Down
Loading