diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index 29ab675525..fd5f2a5dc1 100644
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -50,7 +50,9 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
-
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.quantization.quant_config import AscendQuantConfig
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
 
@@ -183,8 +185,18 @@ def __init__(
                                    hidden_act=config.hidden_act,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+
+        assert isinstance(quant_config, AscendQuantConfig), \
+            "Expected quant_config to be an instance of AscendQuantConfig"
+        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                      AscendW8A8LinearMethod):
+            self.input_layernorm = AddRMSNormW8A8Quant(
+                config.hidden_size,
+                layer=self.self_attn.qkv_proj,
+                eps=config.rms_norm_eps)
+        else:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
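
For context, a minimal, self-contained sketch of the dispatch pattern this change introduces: the input layernorm is swapped for a fused norm-plus-quantize variant only when the downstream qkv projection consumes W8A8 (int8) activations. All names in the sketch (`W8A8Linear`, `FusedAddRMSNormQuant`, `build_input_layernorm`, ...) are hypothetical stand-ins, not the real vllm_ascend classes; it assumes `AddRMSNormW8A8Quant` exists to fold the activation quantization of the following W8A8 layer into the layernorm kernel.

```python
# Illustrative only: a standalone mock-up of the dispatch pattern above.
# None of these classes are the real vllm_ascend types.
from dataclasses import dataclass


class W8A8Linear:
    """Stand-in for a linear layer whose activations are quantized to int8."""


class FP16Linear:
    """Stand-in for an unquantized linear layer."""


@dataclass
class PlainRMSNorm:
    hidden_size: int
    eps: float = 1e-6


@dataclass
class FusedAddRMSNormQuant:
    """Norm that also emits the int8 activations the next layer expects,
    so the separate quantize step before qkv_proj can be skipped."""
    hidden_size: int
    next_layer: object
    eps: float = 1e-6


def build_input_layernorm(hidden_size: int, qkv_proj, eps: float):
    # Fuse norm + activation quantization only when the following
    # projection actually consumes int8 activations.
    if isinstance(qkv_proj, W8A8Linear):
        return FusedAddRMSNormQuant(hidden_size, next_layer=qkv_proj, eps=eps)
    return PlainRMSNorm(hidden_size, eps=eps)


# Quantized qkv_proj -> fused norm; unquantized qkv_proj -> plain RMSNorm.
print(type(build_input_layernorm(4096, W8A8Linear(), 1e-6)).__name__)
print(type(build_input_layernorm(4096, FP16Linear(), 1e-6)).__name__)
```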