
Commit 5b37249

renato-arantes authored and pytorchmergebot committed
Enable fp16 linear layers in PyTorch via ACL (pytorch#144992)
This pull request enables linear layers with the fp16 data type through the Arm® Compute Library (ACL). On a Graviton3 instance running with 16 threads, a linear layer applied to `torch.randn(2048, 4096, dtype=torch.half)` completes in 50+% less time than the same layer applied to `torch.randn(2048, 4096, dtype=torch.float32)`.

Pull Request resolved: pytorch#144992
Approved by: https://github.com/ng-05, https://github.com/digantdesai, https://github.com/malfet
1 parent 6d4f5f7 commit 5b37249
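
For context (not part of the commit), a minimal sketch of the kind of comparison quoted above: it times a CPU `nn.Linear` in fp16 against fp32. The 50+% figure assumes an aarch64 build with ACL enabled, fp16-capable hardware, and 16 threads; the layer shape here is only illustrative.

```python
# Hypothetical micro-benchmark, not taken from this PR: times a CPU
# nn.Linear in fp16 vs. fp32. The fp16 path only reaches oneDNN/ACL on
# an aarch64 build with ACL enabled and fp16-capable hardware.
import time
import torch

torch.set_num_threads(16)  # matches the 16-thread Graviton3 setup cited above

def bench(dtype, iters=50):
    x = torch.randn(2048, 4096, dtype=dtype)
    linear = torch.nn.Linear(4096, 4096).to(dtype)
    with torch.inference_mode():
        linear(x)  # warm-up
        start = time.perf_counter()
        for _ in range(iters):
            linear(x)
    return (time.perf_counter() - start) / iters

print(f"fp32: {bench(torch.float32) * 1e3:.2f} ms/iter")
print(f"fp16: {bench(torch.half) * 1e3:.2f} ms/iter")
```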

File tree — 4 files changed: +35 −11 lines


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 6 additions & 2 deletions
```diff
@@ -1513,8 +1513,12 @@ static void addmm_impl_cpu_(
   // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running kernel with BF16 instructions
   if (transpose_c) {
-    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
-    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+    bool apply_heur =
+        apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
+    if (apply_heur && transpose_a && !transpose_b &&
+        (result.scalar_type() == at::ScalarType::Float ||
+         result.scalar_type() == at::ScalarType::BFloat16 ||
+         result.scalar_type() == at::ScalarType::Half)) {
       try {
         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
         // We have dispatched to ACL GEMM for single precision float
```

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 19 additions & 7 deletions
```diff
@@ -236,15 +236,27 @@ void mkldnn_matmul(
       "mkldnn_matmul: unsupported dims for mat and mat2");
 
 #if defined(__aarch64__)
-  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
-  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
-  TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
-              ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
-              "mkldnn_matmul: only enabled for fp32 and bf16 path");
+  // oneDNN fast-maths mode (enabled by setting the environment variable
+  // ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch fp32 inputs to bf16 kernels
+  // where HW permits. So, both fp32 and bf16 inputs are permitted.
+  TORCH_CHECK(
+      (mat1.scalar_type() == mat2.scalar_type()) &&
+          (mat1.scalar_type() == result.scalar_type()) &&
+          ((mat1.scalar_type() == at::kFloat) ||
+           (mat1.scalar_type() == at::kBFloat16) ||
+           (mat1.scalar_type() == at::kHalf)),
+      "mkldnn_matmul: only enabled for fp32, bf16 and fp16 path");
   // device needs to support bf16 if the inputs are of bf16 type
   if (mat1.scalar_type() == at::kBFloat16) {
-    TORCH_CHECK(mkldnn_bf16_device_check_arm(),
-                "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
+    TORCH_CHECK(
+        mkldnn_bf16_device_check_arm(),
+        "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
+  }
+  // device needs to support fp16 if the inputs are of fp16 type
+  if (mat1.scalar_type() == at::kHalf) {
+    TORCH_CHECK(
+        mkldnn_fp16_device_check_arm(),
+        "mkldnn_matmul: mkldnn_matmul fp16 path needs a cpu with fp16 support");
   }
 #else
   TORCH_CHECK(
```

aten/src/ATen/native/mkldnn/Utils.h

Lines changed: 9 additions & 1 deletion
```diff
@@ -90,6 +90,10 @@ inline bool mkldnn_bf16_device_check_arm() {
   return cpuinfo_initialize() && cpuinfo_has_arm_bf16();
 }
 
+inline bool mkldnn_fp16_device_check_arm() {
+  return cpuinfo_initialize() && cpuinfo_has_arm_neon_fp16();
+}
+
 inline bool is_arm_neoverse() {
   return (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 &&
           (cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1 ||
@@ -102,6 +106,10 @@ constexpr bool mkldnn_bf16_device_check_arm() {
   return false;
 }
 
+inline bool mkldnn_fp16_device_check_arm() {
+  return false;
+}
+
 constexpr bool is_arm_neoverse() {
   return false;
 }
@@ -121,7 +129,7 @@ inline bool mkldnn_fp16_device_check() {
 #if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))
   return ideep::has_fp16_type_support();
 #else
-  return false;
+  return mkldnn_fp16_device_check_arm();
 #endif
 }
 
```
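
As a loose illustration of what the new `mkldnn_fp16_device_check_arm()` gate corresponds to at runtime (this is not PyTorch API, just an assumption-laden sketch for Linux on aarch64): the kernel advertises the `fphp`/`asimdhp` feature flags in `/proc/cpuinfo` when the CPU implements fp16 arithmetic, which is roughly what `cpuinfo_has_arm_neon_fp16()` reports.

```python
# Illustrative Linux-only check (not PyTorch API): approximates the
# cpuinfo_has_arm_neon_fp16() gate by looking for the aarch64 "asimdhp"
# (Advanced SIMD half-precision) feature flag in /proc/cpuinfo.
import platform

def arm_neon_fp16_supported() -> bool:
    if platform.machine() != "aarch64":
        return False
    try:
        with open("/proc/cpuinfo") as f:
            text = f.read()
    except OSError:
        return False
    flags = set()
    for line in text.splitlines():
        if line.startswith("Features"):
            flags.update(line.split(":", 1)[1].split())
    return "asimdhp" in flags

print("NEON fp16 arithmetic available:", arm_neon_fp16_supported())
```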

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -117,7 +117,7 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
     ):
         input_kernel = 1
     if output.is_contiguous(memory_format=torch.contiguous_format) or (
-        TEST_ACL and dtype == torch.bfloat16
+        TEST_ACL and (dtype == torch.bfloat16 or dtype == torch.half)
     ):
         output_kernel = 1
     return input_kernel + output_kernel
```
