Commit 9b07cda

[Qwen-moe] use npu_add_rms_norm_quant operator
Signed-off-by: s30076806 <[email protected]>
1 parent 0f7492d commit 9b07cda

3 files changed (+25 −3 lines)

vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # Whether to use torch_npu.npu_add_rms_norm_quant in layernorm.py.
+    # 0: disable (default)
+    # 1: enable
+    "USE_ADD_RMSNORM_QUANT":
+    lambda: int(os.getenv("USE_ADD_RMSNORM_QUANT", '0')),
 }

 # end-env-vars-definition
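
For context, envs.py maps each variable name to a lambda, so the environment is read lazily at attribute access time rather than once at import. A minimal sketch of that pattern, assuming the module resolves attribute lookups through a PEP 562 module-level __getattr__ the way vLLM's own envs.py does:

import os

# Each entry is a thunk: os.getenv runs when the flag is read,
# not when the module is imported.
env_variables = {
    "USE_ADD_RMSNORM_QUANT":
    lambda: int(os.getenv("USE_ADD_RMSNORM_QUANT", '0')),
}

def __getattr__(name: str):
    # envs.USE_ADD_RMSNORM_QUANT re-evaluates its lambda on each access.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")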

vllm_ascend/models/qwen3_moe.py

Lines changed: 19 additions & 2 deletions
@@ -50,6 +50,10 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.quantization.quant_config import AscendQuantConfig
+import vllm_ascend.envs as envs


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -183,8 +187,21 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if not envs.USE_ADD_RMSNORM_QUANT:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            assert isinstance(quant_config, AscendQuantConfig), \
+                "Expected quant_config to be an instance of AscendQuantConfig"
+            if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                          AscendW8A8LinearMethod):
+                self.input_layernorm = AddRMSNormW8A8Quant(
+                    config.hidden_size,
+                    layer=self.self_attn.qkv_proj,
+                    eps=config.rms_norm_eps)
+            else:
+                self.input_layernorm = RMSNorm(config.hidden_size,
+                                               eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
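
The fused path is opt-in and only engages when qkv_proj is quantized with AscendW8A8LinearMethod; anything else falls back to the plain RMSNorm. A hypothetical way to exercise it end to end (the checkpoint name is illustrative, not taken from this commit):

import os

# Opt in before the model is constructed; the flag decides which
# input_layernorm class each Qwen3-MoE decoder layer builds.
os.environ["USE_ADD_RMSNORM_QUANT"] = "1"

from vllm import LLM, SamplingParams

# Hypothetical W8A8-quantized Qwen3-MoE checkpoint served through
# vllm-ascend's "ascend" quantization backend.
llm = LLM(model="Qwen/Qwen3-30B-A3B-W8A8", quantization="ascend")
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)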

vllm_ascend/ops/layernorm.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         has_weight: bool = True,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        super().__init__(hidden_size)
         self.layer = layer

     def forward(
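
For reference, the fused operator itself is invoked from AddRMSNormW8A8Quant.forward, which this hunk does not show. A sketch of how that call plausibly looks, assuming npu_add_rms_norm_quant returns (quantized_out, _, add_result) and that W8A8 linear layers expose aclnn_input_scale / aclnn_input_offset; both are assumptions, not confirmed by this diff:

from typing import Optional, Tuple, Union

import torch
import torch_npu
from vllm.model_executor.layers.layernorm import RMSNorm

class AddRMSNormW8A8QuantSketch(RMSNorm):
    """Sketch: fuse residual add + RMSNorm + static int8 quant on NPU."""

    def __init__(self, hidden_size: int, layer: torch.nn.Module,
                 eps: float = 1e-6) -> None:
        super().__init__(hidden_size, eps=eps)
        # The downstream quantized linear (here qkv_proj); its input
        # scale/offset parametrize the fused quantization step.
        self.layer = layer

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # One kernel instead of three: x + residual, RMSNorm over
            # the sum, then quantization of the normalized output.
            x, _, residual = torch_npu.npu_add_rms_norm_quant(
                x,
                residual,
                self.weight,
                self.layer.aclnn_input_scale,
                self.layer.aclnn_input_offset,
                epsilon=self.variance_epsilon)
            return x, residual
        # No residual means nothing to fuse; fall back to plain RMSNorm.
        return super().forward(x)

Fusing the residual add, the RMSNorm, and the static int8 quantization into one kernel avoids materializing the intermediate activations between the norm and the quantized qkv_proj, which is the point of this commit.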
