Skip to content

Commit d34f5fe

Browse files
bigPYJ1151 and Isotr0py
authored
[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support on legacy platforms (#27526)
Signed-off-by: jiang1.li <[email protected]> Co-authored-by: Isotr0py <[email protected]>
1 parent bdb01a3 commit d34f5fe

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

docker/Dockerfile.cpu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
7979
######################### BUILD IMAGE #########################
8080
FROM base AS vllm-build
8181

82-
ARG max_jobs=2
82+
ARG max_jobs=32
8383
ENV MAX_JOBS=${max_jobs}
8484

8585
ARG GIT_REPO_CHECK=0

vllm/model_executor/layers/utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88

99
from vllm import _custom_ops as ops
1010
from vllm import envs
11+
from vllm.logger import init_logger
1112
from vllm.platforms import CpuArchEnum, current_platform
1213
from vllm.utils.torch_utils import direct_register_custom_op
1314

15+
logger = init_logger(__name__)
16+
1417

1518
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
1619
# Shuffle weight along the last dimension so that
@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm(
178181
)
179182
if remove_weight:
180183
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
184+
return
181185
elif (
182186
ops._supports_onednn
183187
and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
184188
):
185-
origin_weight = layer.weight
186-
if remove_weight:
187-
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
188-
handler = ops.create_onednn_mm(origin_weight.t(), 32)
189-
layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
190-
else:
191-
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
192-
x, weight, bias
193-
)
189+
try:
190+
origin_weight = layer.weight
191+
handler = ops.create_onednn_mm(origin_weight.t(), 32)
192+
layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
193+
if remove_weight:
194+
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
195+
return
196+
except RuntimeError as e:
197+
logger.warning_once(
198+
"Failed to create oneDNN linear, fallback to torch linear."
199+
f" Exception: {e}"
200+
)
201+
202+
# fallback case
203+
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
204+
x, weight, bias
205+
)
194206

195207

196208
def cpu_unquantized_gemm(

0 commit comments

Comments
 (0)