
Commit 87a0b7b

[bugfix] adapt bugfix for norm_quant_fusion_pass to npugraph_ex (#6726)
### What this PR does / why we need it?

This PR adapts the bugfixes from `norm_quant_fusion_pass` to `graphex_norm_quant_fusion_pass` for the `npugraph_ex` backend. The main changes are:

- Replaced `torch.ops.npu.npu_add_rms_norm` with `torch.ops._C_ascend.npu_add_rms_norm_bias`.
- For patterns without a bias, `None` is passed as the bias argument.
- For patterns with a bias, the separate `add` operation for the bias is removed and the bias is passed directly to `npu_add_rms_norm_bias`. This improves fusion.

These changes ensure consistency and correctness for the RMSNorm and quantization fusion patterns when using `npugraph_ex`.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: vllm-project/vllm@9562912

Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
1 parent 41d056f commit 87a0b7b
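For reference, the semantics the fused op is expected to have can be sketched in plain PyTorch. This is a minimal illustration based only on how the diff below uses `npu_add_rms_norm_bias` (outputs read at index `[0]` and `[2]`, optional bias folded into the kernel); the output layout and exact numerics of the real Ascend kernel are assumptions, not taken from this PR.

```python
import torch


def add_rms_norm_bias_reference(x, residual, weight, bias, eps):
    # Pure-PyTorch sketch of what torch.ops._C_ascend.npu_add_rms_norm_bias
    # is assumed to compute: fused residual add + RMSNorm, with the optional
    # bias folded into the same kernel instead of a separate add afterwards.
    added = x + residual                                    # residual add
    variance = added.pow(2).mean(-1, keepdim=True)          # RMS statistic
    normed = added * torch.rsqrt(variance + eps) * weight   # RMSNorm
    if bias is not None:
        normed = normed + bias                              # bias fused in
    # The patterns in this file read output[0] (normed result) and output[2]
    # (the residual-add result), so a 3-tuple layout is assumed here.
    return normed, None, added
```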

File tree

1 file changed: +12 −6 lines changed


vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py

Lines changed: 12 additions & 6 deletions
@@ -58,7 +58,9 @@ def pattern(
         """
         Pattern for AddRMSNormQuant fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
@@ -123,10 +125,11 @@ def pattern(
         """
         Pattern for AddRMSNormQuantWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, bias, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
-        out0 = out0 + bias
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
         return quantized_output, out1

@@ -188,7 +191,9 @@ def pattern(
         """
         Pattern for AddRMSNormQuantSPPattern fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
@@ -255,10 +260,11 @@ def pattern(
         """
         Pattern for AddRMSNormQuantSPPatternWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, bias, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
-        out0 = out0 + bias
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
         return quantized_output, out1
