Commit c721ae6

[CustomOp] Register RMSNorm instead of overwrite forward_oot (#2284)
### What this PR does / why we need it?

Use `CustomOp.register_oot` to register the custom op instead of overwriting `forward_oot`:

```
from vllm.model_executor.custom_op import CustomOp
CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
```

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

CI passed with newly added and existing tests.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@afa5b7c

---------

Signed-off-by: Icey <[email protected]>
1 parent e14f2ef commit c721ae6
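
For context, the commit replaces the previous monkey-patch style override with vLLM's `CustomOp.register_oot` hook. A minimal sketch of the two patterns, using the class and op names from the diff below (NPU-specific bodies elided; this is an illustration, not the diff itself):

```python
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm


# Old style (removed by this commit): monkey-patch the method onto the vLLM class.
def forward_oot(self, x, residual=None):
    ...  # NPU-specific body elided


RMSNorm.forward_oot = forward_oot


# New style: subclass the op and register the subclass as the out-of-tree
# ("oot") implementation under the op's registered name.
class AscendRMSNorm(RMSNorm):

    def forward_oot(self, x, residual=None):
        ...  # NPU-specific body elided


CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
```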

File tree

4 files changed, +85 -28 lines

- tests/ut/ops/test_layernorm.py
- tests/ut/test_utils.py
- vllm_ascend/ops/layernorm.py
- vllm_ascend/utils.py


tests/ut/ops/test_layernorm.py

Lines changed: 53 additions & 0 deletions
```diff
@@ -0,0 +1,53 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+from vllm.model_executor.layers.layernorm import RMSNorm
+
+
+@pytest.fixture
+def dummy_tensor():
+    return torch.randn(4, 8, dtype=torch.float16)
+
+
+def mock_rms_norm(x, weight, eps):
+    return x + 1, None
+
+
+def mock_add_rms_norm(x, residual, weight, eps):
+    return 2 * x, None, 2 * residual
+
+
+@pytest.mark.parametrize("is_310p_return", [True, False])
+@pytest.mark.parametrize("residual",
+                         [None, torch.randn(4, 8, dtype=torch.float32)])
+@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
+@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
+                         residual, dummy_tensor):
+
+    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
+        layer = RMSNorm(hidden_size=32, eps=1e-05)
+        if residual is not None:
+            out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
+
+            if is_310p_return:
+                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
+                expected_out_x = expected_arg_x + 1
+                expected_out_residual = expected_arg_x.to(residual.dtype)
+
+                mock_rmsnorm.assert_called_once()
+                assert torch.allclose(out_x, expected_out_x)
+                assert torch.allclose(out_residual, expected_out_residual)
+            else:
+                expected_out_x = 2 * dummy_tensor
+                expected_out_residual = 2 * residual
+                mock_add_rmsnorm.assert_called_once()
+                assert torch.allclose(out_x, expected_out_x)
+                assert torch.allclose(out_residual, expected_out_residual)
+        else:
+            out_x = layer.forward(dummy_tensor, residual)
+            expected_out_x = dummy_tensor + 1
+
+            mock_rmsnorm.assert_called_once()
+            assert torch.allclose(out_x, expected_out_x)
```
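
Aside (not part of the commit): the test above builds on `unittest.mock.patch(..., side_effect=...)`, which swaps the target for a predictable stand-in and records every call, so the assertions can check both the call count and the arithmetic the stand-in applies. A minimal, self-contained illustration of that pattern against a standard-library target:

```python
import math
from unittest.mock import patch


def fake_sqrt(x):
    # Predictable stand-in, analogous to mock_rms_norm returning x + 1 above.
    return x + 1


with patch("math.sqrt", side_effect=fake_sqrt) as mocked:
    assert math.sqrt(3) == 4  # the patched target forwards to fake_sqrt
    mocked.assert_called_once_with(3)

assert math.sqrt(9) == 3.0  # original behaviour is restored outside the block
```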

tests/ut/test_utils.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -347,20 +347,22 @@ def test_update_aclgraph_sizes(self):
     @mock.patch("vllm.model_executor.custom_op.CustomOp")
     @mock.patch("vllm_ascend.ops.activation.AscendQuickGELU")
     @mock.patch("vllm_ascend.ops.activation.AscendSiluAndMul")
-    def test_register_ascend_customop(self, mock_ascend_silu_and_mul,
+    @mock.patch("vllm_ascend.ops.layernorm.AscendRMSNorm")
+    def test_register_ascend_customop(self, mock_ascend_rmsnorm,
+                                      mock_ascend_silu_and_mul,
                                       mock_ascend_quick_gelu, mock_customop):
         utils._ASCEND_CUSTOMOP_IS_REIGISTERED = False
 
         # ascend custom op is not registered
         utils.register_ascend_customop()
-        # should call register_oot twice
-        self.assertEqual(mock_customop.register_oot.call_count, 2)
+        # should call register_oot three times
+        self.assertEqual(mock_customop.register_oot.call_count, 3)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
-        # should not register_oot again, thus only called twice in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 2)
+        # should not register_oot again, thus still only called three times in this ut
+        self.assertEqual(mock_customop.register_oot.call_count, 3)
 
 
 class TestProfileExecuteDuration(TestBase):
```

vllm_ascend/ops/layernorm.py

Lines changed: 22 additions & 23 deletions
```diff
@@ -20,8 +20,6 @@
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import is_310p
-
 
 class AddRMSNormW8A8Quant(RMSNorm):
     # Fuse AddRmsNorm and W8A8 quantization ops together
@@ -60,27 +58,28 @@ def forward(
         return x
 
 
-def forward_oot(
-    self,
-    x: torch.Tensor,
-    residual: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    import torch_npu
-
-    if residual is not None:
-        if is_310p():
-            orig_dtype = residual.dtype
-            x = x + residual.to(x.dtype)
-            residual = x.to(orig_dtype)
-            x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                          self.variance_epsilon)
-        else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-        return x, residual
+class AscendRMSNorm(RMSNorm):
 
-    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-    return x
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
 
+        from vllm_ascend.utils import is_310p
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
+            return x, residual
 
-RMSNorm.forward_oot = forward_oot
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
```
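
For orientation (not from the commit): the fused `npu_add_rms_norm` call above combines the residual add with RMS normalization, returning both the normalized hidden states and the updated residual. A rough unfused PyTorch sketch of that computation, assuming the kernel follows the standard add-then-RMSNorm decomposition (the return layout mirrors how the diff unpacks it; this is not the NPU kernel itself):

```python
import torch


def add_rms_norm_reference(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float):
    # Residual add first; the sum also becomes the new residual stream.
    added = x + residual
    # RMS normalization: scale by the reciprocal root-mean-square, then by the weight.
    rms = torch.rsqrt(added.pow(2).mean(dim=-1, keepdim=True) + eps)
    return added * rms * weight, added
```

This is also consistent with the 310P branch above, which performs the add in plain Python and then calls the unfused `npu_rms_norm`.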

vllm_ascend/utils.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -479,6 +479,9 @@ def register_ascend_customop():
     CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
                           name="SiluAndMul")
 
+    from vllm_ascend.ops.layernorm import AscendRMSNorm
+    CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
+
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
 
```
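
Usage note (not part of the commit): `register_ascend_customop()` is the entry point that wires the Ascend ops into vLLM's `CustomOp` registry, and the `_ASCEND_CUSTOMOP_IS_REIGISTERED` flag makes repeated calls a no-op (as the updated unit test asserts). A minimal sketch, assuming vllm and vllm-ascend are installed; where exactly vllm-ascend triggers this during startup is outside this diff:

```python
from vllm_ascend.utils import register_ascend_customop

# First call registers AscendRMSNorm (and the activation ops) via CustomOp.register_oot.
register_ascend_customop()

# Subsequent calls are no-ops thanks to the _ASCEND_CUSTOMOP_IS_REIGISTERED guard.
register_ascend_customop()
```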

0 commit comments
