```diff
@@ -50,6 +50,10 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.quantization.quant_config import AscendQuantConfig
+import vllm_ascend.envs as envs
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -183,8 +187,21 @@ def __init__(
                                    hidden_act=config.hidden_act,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if not envs.USE_ADD_RMSNORM_QUANT:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            assert isinstance(quant_config, AscendQuantConfig), \
+                "Expected quant_config to be an instance of AscendQuantConfig"
+            if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                          AscendW8A8LinearMethod):
+                self.input_layernorm = AddRMSNormW8A8Quant(
+                    config.hidden_size,
+                    layer=self.self_attn.qkv_proj,
+                    eps=config.rms_norm_eps)
+            else:
+                self.input_layernorm = RMSNorm(config.hidden_size,
+                                               eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
```
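For readers who don't have `vllm_ascend.ops.layernorm` at hand, below is a minimal plain-PyTorch sketch of the semantics `AddRMSNormW8A8Quant` is expected to fuse into a single NPU kernel: residual add, RMSNorm, and static int8 activation quantization for the quantized linear layer passed via `layer=` (here `qkv_proj`). The class body and the scale/offset attributes are illustrative assumptions, not the Ascend implementation:

```python
import torch
import torch.nn as nn


class AddRMSNormW8A8QuantSketch(nn.Module):
    """Reference semantics for a fused add + RMSNorm + int8-quant step.

    Illustrative only: the real AddRMSNormW8A8Quant dispatches to one
    fused kernel on the NPU and reads the activation scale/offset from
    the quantized linear layer it feeds; here they are plain floats.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6,
                 input_scale: float = 1.0, input_offset: float = 0.0):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        self.input_scale = input_scale    # assumed static W8A8 params
        self.input_offset = input_offset

    def forward(self, x: torch.Tensor, residual: torch.Tensor):
        # "Add": fold the residual in; the sum becomes the next residual,
        # matching vLLM's (output, residual) RMSNorm convention.
        x = x + residual
        new_residual = x
        # RMSNorm, computed in fp32 for numerical stability.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        y = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
        y = y * self.weight
        # Static int8 activation quantization for the following linear
        # layer, so a W8A8 qkv_proj consumes int8 directly instead of
        # re-quantizing after a separate RMSNorm.
        q = torch.round(y.float() / self.input_scale + self.input_offset)
        q = q.clamp(-128, 127).to(torch.int8)
        return q, new_residual
```

Seen this way, the gating in the diff is the safety condition for the fusion: it applies only when `qkv_proj` is quantized with `AscendW8A8LinearMethod` and therefore expects statically quantized int8 activations. Any other configuration, and the default with `USE_ADD_RMSNORM_QUANT` off, falls back to the plain `RMSNorm`.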