
Commit c3fee66

[Model] Optimizing gemma3 model's GemmaRMSNorm function (#3151)
### What this PR does / why we need it?
Before this optimization, the RMSNorm time in one decoding step was 531.5 µs. After the optimization, it is 105 µs. I closed the previous PR (#2456) by mistake and have resubmitted it here.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@b106890

---------

Signed-off-by: socrahow <[email protected]>
1 parent dd56e93 commit c3fee66
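For reference, here is a minimal eager-mode sketch of the Gemma-style RMSNorm that this commit accelerates: normalize by the root mean square over the hidden dimension, scale by `1 + weight`, and optionally fold in a residual add. This is an illustrative sketch, not the upstream vLLM implementation; it only shows the sequence of small elementwise ops that the fused `torch_npu` kernels in the diff below replace.

```python
from typing import Optional, Tuple, Union

import torch


def gemma_rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Eager sketch of Gemma-style RMSNorm (illustrative only)."""
    if residual is not None:
        x = x + residual  # the fused path folds this add into npu_add_rms_norm
        residual = x
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + weight), which is why the diff below passes
    # `1.0 + self.weight` to the fused NPU kernels.
    x = (x * (1.0 + weight.float())).to(orig_dtype)
    return x if residual is None else (x, residual)
```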

File tree

- vllm_ascend/ops/layernorm.py
- vllm_ascend/utils.py

2 files changed: +31 −2 lines changed

vllm_ascend/ops/layernorm.py

Lines changed: 28 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 import torch
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 
 
 def _addrmsnorm_forward_oot(
@@ -130,3 +130,30 @@ def forward_oot(
             x, residual = super().forward_oot(x, residual)
             return x.add_(self.bias), residual
         return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
+
+
+class AscendGemmaRMSNorm(GemmaRMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        from vllm_ascend.utils import is_310p
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, 1.0 + self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                      self.variance_epsilon)
+        return x
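Design note: on non-310P devices the residual add and the normalization are handled by a single fused `npu_add_rms_norm` call, while on 310P the add is done in eager mode before `npu_rms_norm`, presumably because the fused add-and-norm path is not taken on that SoC. The `1.0 + self.weight` argument reflects Gemma's convention of storing the scale as an offset from 1. Below is a quick CPU-only sanity check of that algebra, assuming PyTorch >= 2.4 for `torch.nn.functional.rms_norm`; it illustrates the identity only and does not touch the NPU kernels.

```python
import torch
import torch.nn.functional as F

hidden = 8
x = torch.randn(2, hidden, dtype=torch.float32)
w = torch.randn(hidden, dtype=torch.float32) * 0.1  # Gemma-style weight, near 0
eps = 1e-6

# Standard RMSNorm with the weight shifted by 1 ...
shifted = F.rms_norm(x, [hidden], weight=1.0 + w, eps=eps)
# ... equals the explicit Gemma formulation.
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * (1.0 + w)
assert torch.allclose(shifted, manual, atol=1e-5)
```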

vllm_ascend/utils.py

Lines changed: 3 additions & 1 deletion
@@ -505,7 +505,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
     from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
                                                   AscendSharedFusedMoE)
-    from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
+    from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
+                                           AscendQuantRMSNorm, AscendRMSNorm)
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
                                         AscendQKVParallelLinear,
@@ -530,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "ParallelLMHead": AscendParallelLMHead,
         "LogitsProcessor": AscendLogitsProcessor,
         "RMSNorm": AscendRMSNorm,
+        "GemmaRMSNorm": AscendGemmaRMSNorm,
         "FusedMoE": AscendFusedMoE,
         "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
