|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
4 |
| -import logging |
5 | 4 | from typing import Optional, Union
|
6 | 5 |
|
7 | 6 | import torch
|
8 | 7 | import torch.nn as nn
|
9 | 8 |
|
10 |
| -logger = logging.getLogger("vllm_xpu_kernel") |
11 |
| - |
12 |
| - |
13 |
| -class CustomOp(nn.Module): |
14 |
| - """ |
15 |
| - Base class for custom ops. |
16 |
| - Dispatches the forward method to the appropriate backend. |
17 |
| - """ |
18 |
| - |
19 |
| - def __new__(cls, *args, **kwargs): |
20 |
| - try: |
21 |
| - op_name = cls.__name__ |
22 |
| - except AttributeError: |
23 |
| - raise TypeError( |
24 |
| - f"Cannot instantiate '{cls.__name__}': its 'name' attribute " |
25 |
| - f"was not set, possibly because it was not decorated with " |
26 |
| - f"@CustomOp.register, or it's the CustomOp base class itself." |
27 |
| - ) from None |
28 |
| - logger.debug("Instantiating custom op: %s", op_name) |
29 |
| - op_cls_to_instantiate = cls |
30 |
| - return super().__new__(op_cls_to_instantiate) |
31 |
| - |
32 |
| - def __init__(self): |
33 |
| - super().__init__() |
34 |
| - self._forward_method = self.dispatch_forward() |
35 |
| - |
36 |
| - def forward(self, *args, **kwargs): |
37 |
| - return self._forward_method(*args, **kwargs) |
38 |
| - |
39 |
| - def forward_native(self, *args, **kwargs): |
40 |
| - """PyTorch-native implementation of the forward method. |
41 |
| - This method is optional. If implemented, it can be used with compilers |
42 |
| - such as torch.compile or PyTorch XLA. Also, it can be used for testing |
43 |
| - purposes. |
44 |
| - """ |
45 |
| - raise NotImplementedError |
46 |
| - |
47 |
| - def forward_cuda(self, *args, **kwargs): |
48 |
| - raise NotImplementedError |
49 |
| - |
50 |
| - def forward_xpu(self, *args, **kwargs): |
51 |
| - # By default, we assume that XPU ops are compatible with the |
52 |
| - # PyTorch-native implementation. |
53 |
| - return self.forward_native(*args, **kwargs) |
54 |
| - |
55 |
| - def forward_cpu(self, *args, **kwargs): |
56 |
| - # By default, we assume that CPU ops are compatible with CUDA ops. |
57 |
| - return self.forward_cuda(*args, **kwargs) |
58 |
| - |
59 |
| - def dispatch_forward(self): |
60 |
| - return self.forward_xpu |
| 9 | +from tests.ops.custom_ops import CustomOp |
61 | 10 |
|
62 | 11 |
|
63 | 12 | def fused_add_rms_norm(
|
@@ -181,17 +130,7 @@ def forward_xpu(
|
181 | 130 | x: torch.Tensor,
|
182 | 131 | residual: Optional[torch.Tensor] = None,
|
183 | 132 | ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
184 |
| - if self.variance_size_override is not None: |
185 |
| - return self.forward_native(x, residual) |
186 |
| - |
187 |
| - add_residual = residual is not None |
188 |
| - norm_func = dispatch_cuda_rmsnorm_func(add_residual) |
189 |
| - |
190 |
| - if add_residual: |
191 |
| - return norm_func(x, residual, self.weight.data, |
192 |
| - self.variance_epsilon) |
193 |
| - else: |
194 |
| - return norm_func(x, self.weight.data, self.variance_epsilon) |
| 133 | + return self.forward_cuda(x, residual) |
195 | 134 |
|
196 | 135 | def extra_repr(self) -> str:
|
197 | 136 | s = f"hidden_size={self.weight.data.size(0)}"
|
|
0 commit comments