@@ -8,9 +8,12 @@
 
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
+logger = init_logger(__name__)
+
 
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm( |
         )
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+        return
     elif (
         ops._supports_onednn
         and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
     ):
-        origin_weight = layer.weight
-        if remove_weight:
-            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
-        handler = ops.create_onednn_mm(origin_weight.t(), 32)
-        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
-    else:
-        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
-            x, weight, bias
-        )
+        try:
+            origin_weight = layer.weight
+            handler = ops.create_onednn_mm(origin_weight.t(), 32)
+            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
+            if remove_weight:
+                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            return
+        except RuntimeError as e:
+            logger.warning_once(
+                "Failed to create oneDNN linear, falling back to torch linear."
+                f" Exception: {e}"
+            )
+
+    # fallback case
+    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
+        x, weight, bias
+    )
 
 
 def cpu_unquantized_gemm(
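The second hunk follows a dispatch-with-fallback pattern: attempt to build the fast oneDNN handle once at dispatch time, `return` on success, and otherwise log once and bind the generic `torch.nn.functional.linear` path. Below is a minimal standalone sketch of that pattern under stated assumptions: `dispatch_linear`, `build_fast_path`, and `always_fails` are illustrative names that are not part of this PR, and plain `logging` stands in for vLLM's `init_logger`.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def dispatch_linear(layer, build_fast_path):
    # Try the fast backend first; its constructor may raise RuntimeError
    # (unsupported ISA, dtype, shape, ...), mirroring ops.create_onednn_mm.
    try:
        handle = build_fast_path(layer.weight.t())
        layer.cpu_linear = lambda x, weight, bias: handle(x, bias)
        return
    except RuntimeError as e:
        logger.warning("Fast linear unavailable, falling back to torch: %s", e)

    # Fallback: generic PyTorch linear, always available.
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
        x, weight, bias
    )


# Usage: a backend constructor that always fails exercises the fallback.
def always_fails(weight_t):
    raise RuntimeError("backend not supported on this CPU")


layer = torch.nn.Linear(16, 8)
dispatch_linear(layer, always_fails)
x = torch.randn(4, 16)
out = layer.cpu_linear(x, layer.weight, layer.bias)  # torch fallback path
```

Returning early on success leaves the fallback assignment as the single exit path for every failure mode, which is also why the diff dedents the fallback out of the old `else:` branch.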