Commit 97d4c29

refactor files by ops in test folder (#4)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 226af61 commit 97d4c29

File tree: 4 files changed, +62 -65 lines changed


benchmark/benchmark_layernorm.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@

 import torch

-from tests.ops import RMSNorm
+from tests.ops.layernorm_op import RMSNorm
 from tests.utils import STR_DTYPE_TO_TORCH_DTYPE

tests/ops/custom_ops.py

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import logging
+
+import torch.nn as nn
+
+logger = logging.getLogger("vllm_xpu_kernel")
+
+
+class CustomOp(nn.Module):
+    """
+    Base class for custom ops.
+    Dispatches the forward method to the appropriate backend.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        try:
+            op_name = cls.__name__
+        except AttributeError:
+            raise TypeError(
+                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
+                f"was not set, possibly because it was not decorated with "
+                f"@CustomOp.register, or it's the CustomOp base class itself."
+            ) from None
+        logger.debug("Instantiating custom op: %s", op_name)
+        op_cls_to_instantiate = cls
+        return super().__new__(op_cls_to_instantiate)
+
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+        This method is optional. If implemented, it can be used with compilers
+        such as torch.compile or PyTorch XLA. Also, it can be used for testing
+        purposes.
+        """
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_xpu(self, *args, **kwargs):
+        # By default, we assume that XPU ops are compatible with the
+        # PyTorch-native implementation.
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        # By default, we assume that CPU ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def dispatch_forward(self):
+        return self.forward_xpu
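For orientation: with this base class, a concrete op only needs to supply forward_native (and forward_cuda where a dedicated kernel exists). CustomOp.forward routes every call through dispatch_forward, which in this file always resolves to forward_xpu and falls back to the native path. A minimal sketch of the dispatch flow follows; the ScaleOp class and its tensors are hypothetical illustrations, not part of this commit.

# Hypothetical subclass sketch; ScaleOp is not part of this commit and only
# illustrates how CustomOp dispatches forward() on this XPU-focused backend.
import torch

from tests.ops.custom_ops import CustomOp


class ScaleOp(CustomOp):
    def __init__(self, scale: float):
        # CustomOp.__init__ binds self._forward_method = self.forward_xpu
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation; forward_xpu falls back to this by default.
        return x * self.scale


op = ScaleOp(2.0)
y = op(torch.ones(4))  # forward -> forward_xpu -> forward_native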

tests/ops.py renamed to tests/ops/layernorm_op.py

Lines changed: 2 additions & 63 deletions

@@ -1,63 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import logging
 from typing import Optional, Union

 import torch
 import torch.nn as nn

-logger = logging.getLogger("vllm_xpu_kernel")
-
-
-class CustomOp(nn.Module):
-    """
-    Base class for custom ops.
-    Dispatches the forward method to the appropriate backend.
-    """
-
-    def __new__(cls, *args, **kwargs):
-        try:
-            op_name = cls.__name__
-        except AttributeError:
-            raise TypeError(
-                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
-                f"was not set, possibly because it was not decorated with "
-                f"@CustomOp.register, or it's the CustomOp base class itself."
-            ) from None
-        logger.debug("Instantiating custom op: %s", op_name)
-        op_cls_to_instantiate = cls
-        return super().__new__(op_cls_to_instantiate)
-
-    def __init__(self):
-        super().__init__()
-        self._forward_method = self.dispatch_forward()
-
-    def forward(self, *args, **kwargs):
-        return self._forward_method(*args, **kwargs)
-
-    def forward_native(self, *args, **kwargs):
-        """PyTorch-native implementation of the forward method.
-        This method is optional. If implemented, it can be used with compilers
-        such as torch.compile or PyTorch XLA. Also, it can be used for testing
-        purposes.
-        """
-        raise NotImplementedError
-
-    def forward_cuda(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def forward_xpu(self, *args, **kwargs):
-        # By default, we assume that XPU ops are compatible with the
-        # PyTorch-native implementation.
-        return self.forward_native(*args, **kwargs)
-
-    def forward_cpu(self, *args, **kwargs):
-        # By default, we assume that CPU ops are compatible with CUDA ops.
-        return self.forward_cuda(*args, **kwargs)
-
-    def dispatch_forward(self):
-        return self.forward_xpu
+from tests.ops.custom_ops import CustomOp


 def fused_add_rms_norm(
@@ -181,17 +130,7 @@ def forward_xpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        if self.variance_size_override is not None:
-            return self.forward_native(x, residual)
-
-        add_residual = residual is not None
-        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
-
-        if add_residual:
-            return norm_func(x, residual, self.weight.data,
-                             self.variance_epsilon)
-        else:
-            return norm_func(x, self.weight.data, self.variance_epsilon)
+        return self.forward_cuda(x, residual)

     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"

tests/test_layernorm.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 import pytest
 import torch

-from tests.ops import RMSNorm
+from tests.ops.layernorm_op import RMSNorm
 from tests.utils import opcheck

 DTYPES = [torch.half, torch.bfloat16]
