@@ -50,7 +50,9 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
-
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.quantization.quant_config import AscendQuantConfig
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
 
@@ -183,8 +185,18 @@ def __init__(
                                hidden_act=config.hidden_act,
                                quant_config=quant_config,
                                prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+
+        assert isinstance(quant_config, AscendQuantConfig), \
+            "Expected quant_config to be an instance of AscendQuantConfig"
+        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                      AscendW8A8LinearMethod):
+            self.input_layernorm = AddRMSNormW8A8Quant(
+                config.hidden_size,
+                layer=self.self_attn.qkv_proj,
+                eps=config.rms_norm_eps)
+        else:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
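Note: the diff swaps the decoder layer's input layernorm for AddRMSNormW8A8Quant, a fused add + RMSNorm + quantize op, whenever the qkv projection's underlying quant scheme is AscendW8A8LinearMethod, presumably so the int8 activation quantization feeding the W8A8 GEMM is folded into the norm kernel instead of running as a separate op. Below is a minimal standalone sketch of that dispatch pattern using hypothetical stand-in classes; only the names AddRMSNormW8A8Quant, AscendW8A8LinearMethod, and the .quant_method.quant_method attribute chain come from the diff itself.

# Standalone sketch of the dispatch above; every class here is a
# hypothetical stand-in, not real vllm-ascend code.

class W8A8LinearMethod:
    """Stand-in for AscendW8A8LinearMethod (the concrete quant scheme)."""

class QuantMethodWrapper:
    """Stand-in for the per-layer quant method wrapper: the concrete scheme
    hangs off .quant_method, which is why the model code above checks
    qkv_proj.quant_method.quant_method."""
    def __init__(self, scheme):
        self.quant_method = scheme

class QKVLinear:
    """Stand-in for the qkv projection layer."""
    def __init__(self, scheme):
        self.quant_method = QuantMethodWrapper(scheme)

class RMSNorm:
    def __init__(self, hidden_size, eps):
        self.hidden_size, self.eps = hidden_size, eps

class AddRMSNormW8A8Quant(RMSNorm):
    """Stand-in for the fused add + RMSNorm + quantize op; it takes a handle
    to the downstream layer, presumably to read that layer's input
    quantization parameters at runtime."""
    def __init__(self, hidden_size, layer, eps):
        super().__init__(hidden_size, eps)
        self.layer = layer

def build_input_layernorm(qkv_proj, hidden_size, eps):
    # Use the fused op only when the projection actually consumes int8
    # activations; fall back to plain RMSNorm for every other scheme.
    if isinstance(qkv_proj.quant_method.quant_method, W8A8LinearMethod):
        return AddRMSNormW8A8Quant(hidden_size, layer=qkv_proj, eps=eps)
    return RMSNorm(hidden_size, eps)

# e.g. build_input_layernorm(QKVLinear(W8A8LinearMethod()), 4096, 1e-6)
# yields the fused variant; any other scheme yields plain RMSNorm.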