Commit a708ebc

[Qwen-moe] use npu_add_rms_norm_quant operator

Signed-off-by: s30076806 <[email protected]>
1 parent: 0f7492d

2 files changed: +24 −2 lines

vllm_ascend/envs.py
Lines changed: 5 additions & 0 deletions

@@ -159,6 +159,11 @@
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # Whether to use torch_npu.npu_add_rms_norm_quant in layernorm.py
+    # 0: disable (default)
+    # 1: enable
+    "USE_ADD_RMSNORM_QUANT":
+    lambda: int(os.getenv("USE_ADD_RMSNORM_QUANT", '0')),
 }

 # end-env-vars-definition
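
The entry maps the variable name to a lambda, so the value is read from the environment when vllm_ascend.envs.USE_ADD_RMSNORM_QUANT is accessed rather than at import time. A minimal sketch of enabling the flag, assuming the usual vllm env-var pattern where module attribute access invokes the lambda:

import os

# Set the flag before it is first read; the lambda above calls
# os.getenv at evaluation time, not at definition time.
os.environ["USE_ADD_RMSNORM_QUANT"] = "1"

import vllm_ascend.envs as envs

# int("1") == 1 is truthy, so qwen3_moe.py below takes the
# AddRMSNormW8A8Quant branch instead of plain RMSNorm.
assert envs.USE_ADD_RMSNORM_QUANT == 1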

vllm_ascend/models/qwen3_moe.py
Lines changed: 19 additions & 2 deletions

@@ -50,6 +50,10 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.quantization.quant_config import AscendQuantConfig
+import vllm_ascend.envs as envs


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -183,8 +187,21 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if not envs.USE_ADD_RMSNORM_QUANT:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            assert isinstance(quant_config, AscendQuantConfig), \
+                "Expected quant_config to be an instance of AscendQuantConfig"
+            if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                          AscendW8A8LinearMethod):
+                self.input_layernorm = AddRMSNormW8A8Quant(
+                    config.hidden_size,
+                    layer=self.self_attn.qkv_proj,
+                    eps=config.rms_norm_eps)
+            else:
+                self.input_layernorm = RMSNorm(config.hidden_size,
+                                               eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
