Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def get_chip_type() -> str:
else:
raise ValueError(f"Unable to recognize chip name: {chip_name}, please manually set env SOC_VERSION")
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Get chip info failed: {e}")
logging.warning(f"Get chip info failed: {e}")
return ""
except FileNotFoundError:
logging.warning(
"npu-smi command not found, if this is an npu envir, please check if npu driver is installed correctly."
Expand Down
83 changes: 77 additions & 6 deletions vllm_ascend/device/device_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,40 @@ def npu_moe_init_routing(
quant_mode=quant_mode,
)

@staticmethod
def normalize_mxfp8_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
    """Return *scale* unchanged.

    Base-adaptor no-op hook for MXFP8 scale layout normalization. Device
    adaptors that need a packed scale layout override this (the A5 adaptor
    reshapes a 2-D scale of shape (rows, cols) into (rows, cols // 2, 2));
    here the tensor — or ``None`` — is passed through untouched.
    """
    return scale

@staticmethod
def moe_gating_top_k(
    x: torch.Tensor,
    *,
    k: int,
    k_group: int,
    group_count: int,
    group_select_mode: int,
    renorm: int,
    norm_type: int,
    out_flag: bool,
    routed_scaling_factor: float = 1.0,
    eps: float = 1e-20,
    bias_opt: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch grouped top-k expert gating to the ``_C_ascend`` custom op.

    Thin passthrough: every keyword is forwarded verbatim to
    ``torch.ops._C_ascend.moe_gating_top_k``. The returned expert-id tensor
    is cast to ``torch.int32`` before being handed back to the caller.

    Returns:
        ``(topk_weights, topk_ids, out)`` as produced by the custom op,
        with ``topk_ids`` converted to ``int32``.
    """
    op_kwargs = dict(
        k=k,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=renorm,
        norm_type=norm_type,
        out_flag=out_flag,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
        bias_opt=bias_opt,
    )
    weights, expert_ids, gating_out = torch.ops._C_ascend.moe_gating_top_k(x, **op_kwargs)
    expert_ids = expert_ids.to(torch.int32)
    return weights, expert_ids, gating_out

@staticmethod
def npu_dynamic_quant(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -198,6 +232,46 @@ def npu_moe_init_routing(
quant_mode=quant_mode,
)

@staticmethod
def normalize_mxfp8_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
if scale is None or scale.ndim != 2:
return scale
if scale.shape[-1] % 2 != 0:
raise ValueError(f"Invalid MXFP8 scale shape: {tuple(scale.shape)}")
return scale.reshape(scale.shape[0], scale.shape[1] // 2, 2)

@staticmethod
def moe_gating_top_k(
    x: torch.Tensor,
    *,
    k: int,
    k_group: int,
    group_count: int,
    group_select_mode: int,
    renorm: int,
    norm_type: int,
    out_flag: bool,
    routed_scaling_factor: float = 1.0,
    eps: float = 1e-20,
    bias_opt: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Grouped top-k expert gating via ``torch_npu.npu_moe_gating_top_k``.

    The underlying NPU op is always invoked with ``renorm=0``; when the
    caller requests ``renorm == 1`` together with softmax scoring
    (``norm_type == 0``), the renormalization is emulated afterwards by
    rescaling each token's top-k weights to sum to 1.

    NOTE(review): ``out_flag`` is accepted for interface parity with the
    other adaptor but is not forwarded to the NPU op — confirm the op
    ignores it on this device.

    Returns:
        ``(topk_weights, topk_ids, out)`` with ``topk_ids`` cast to
        ``int32``.
    """
    gating_kwargs = dict(
        k=k,
        bias=bias_opt,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=0,  # renorm is emulated below instead of inside the op
        norm_type=norm_type,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
    )
    weights, expert_ids, gating_out = torch_npu.npu_moe_gating_top_k(x, **gating_kwargs)

    # Emulate renorm=1 for softmax scoring: normalize each row of selected
    # weights so it sums to 1.
    if renorm == 1 and norm_type == 0:
        weights = weights.div(weights.sum(-1, keepdim=True))

    return weights, expert_ids.to(torch.int32), gating_out

@staticmethod
def npu_dynamic_quant(
hidden_states: torch.Tensor,
Expand All @@ -215,12 +289,9 @@ def npu_dynamic_quant(
)

if dynamic_scale is None:
return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
hidden_states, dynamic_scale = torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)

if dynamic_scale.ndim == 2:
dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2)

return hidden_states, dynamic_scale
return hidden_states, A5DeviceAdaptor.normalize_mxfp8_scale_layout(dynamic_scale)

@staticmethod
def npu_grouped_matmul_swiglu_quant(
Expand Down Expand Up @@ -257,7 +328,7 @@ def npu_grouped_matmul_swiglu_quant(
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
)
return out, out_scale, None
return out, A5DeviceAdaptor.normalize_mxfp8_scale_layout(out_scale), None

@staticmethod
def get_quant_gmm2_kwargs(
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/ops/fused_moe/experts_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch

from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.utils import get_weight_prefetch_method


Expand Down Expand Up @@ -216,7 +217,7 @@ def _select_experts_with_fusion_ops(
norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
topk_weights, topk_ids, _ = DeviceOperator.moe_gating_top_k(
router_logits,
k=top_k,
k_group=topk_group,
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/ops/fused_moe/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def quant_apply_mlp(
quantized_hidden_states = None
else:
unquantized_hidden_states = None
pertoken_scale = dynamic_scale
pertoken_scale = DeviceOperator.normalize_mxfp8_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
quantized_hidden_states = hidden_states

bias1, bias2 = None, None
Expand Down
5 changes: 4 additions & 1 deletion vllm_ascend/quantization/methods/w8a8_mxfp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ def apply(
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
random_matrix = torch.rand(
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)

topk_weights = topk_weights.to(x.dtype)

Expand Down
Loading