Skip to content

Commit 866142f

Browse files
Revert "Update the heuristic for AArch64 bmm/baddbmm (pytorch#149122)"
This reverts commit d759a51. Reverted pytorch#149122 on behalf of https://github.com/jeanschmidt due to breaking internal models, @malfet may you help merge this? ([comment](pytorch#149122 (comment)))
1 parent 5859582 commit 866142f

File tree

2 files changed

+39
-54
lines changed

2 files changed

+39
-54
lines changed

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,41 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
13601360
#endif
13611361

13621362

1363+
static inline int64_t get_mkldnn_matmul_min_dim() {
1364+
static auto value = [&] {
1365+
const int64_t default_min_dim = [&] {
1366+
// Minimum dimension requirement for MKLDNN; derived based on experiments.
1367+
//it's enabled on all Neoverse cpus.
1368+
return is_arm_neoverse() ? 8 : 0;
1369+
}();
1370+
const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_DIM");
1371+
return value.has_value() ? std::stoi(value.value()) : default_min_dim;
1372+
}();
1373+
return value;
1374+
}
1375+
1376+
1377+
static inline int64_t get_mkldnn_matmul_min_size() {
1378+
static auto value = [&] {
1379+
const int64_t default_min_size = [&] {
1380+
// Minimum size requirement for MKLDNN; derived based on experiments.
1381+
// it's enabled on all Neoverse cpus.
1382+
return is_arm_neoverse() ? 8 * 1024 : 0;
1383+
}();
1384+
const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_SIZE");
1385+
return value.has_value() ? std::stoi(value.value()) : default_min_size;
1386+
}();
1387+
return value;
1388+
}
1389+
1390+
1391+
static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
1392+
const int64_t min_dim = get_mkldnn_matmul_min_dim();
1393+
const int64_t min_size = get_mkldnn_matmul_min_size();
1394+
return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
1395+
}
1396+
1397+
13631398
static void addmm_impl_cpu_(
13641399
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
13651400
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1479,7 +1514,8 @@ static void addmm_impl_cpu_(
14791514
// that will call then into Arm® Compute Library (ACL) GEMM kernel and also
14801515
// additionally have support for running kernel with BF16 instructions
14811516
if (transpose_c) {
1482-
if (use_mkldnn_matmul(b, a, c) && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
1517+
bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
1518+
if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
14831519
try {
14841520
mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
14851521
// We have dispatched to ACL GEMM for single precision float
@@ -1735,7 +1771,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
17351771
(strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
17361772
};
17371773

1738-
if (use_mkldnn_matmul(batch1, batch2, self_or_result)) {
1774+
bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
1775+
if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
17391776
try {
17401777
mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
17411778
return;

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -391,42 +391,6 @@ void mkldnn_matmul(
391391

392392
}
393393

394-
#if AT_MKLDNN_ACL_ENABLED()
395-
// Experimentally derived heuristics for MKLDNN+ACL on NEOVERSE cores
396-
static inline int64_t get_mkldnn_acl_addmm_min_dim() {
397-
static auto value = [&] {
398-
const int64_t default_min_dim = [&] {
399-
return is_arm_neoverse() ? 8 : 0;
400-
}();
401-
const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_DIM");
402-
return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
403-
}();
404-
return value;
405-
}
406-
407-
static inline int64_t get_mkldnn_acl_addmm_min_size() {
408-
static auto value = [&] {
409-
const int64_t default_min_size = [&] {
410-
return is_arm_neoverse() ? 8 * 1024 : 0;
411-
}();
412-
const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_SIZE");
413-
return ptr != nullptr ? std::atoi(ptr) : default_min_size;
414-
}();
415-
return value;
416-
}
417-
418-
static inline int64_t get_mkldnn_acl_bmm_baddbmm_threshold() {
419-
static auto value = [&] {
420-
const int64_t default_threshold = [&] {
421-
return is_arm_neoverse() ? 1L << 22 : 0;
422-
}();
423-
const char* ptr = std::getenv("TORCH_MKLDNN_BMM_BADDBMM_THRESHOLD");
424-
return ptr != nullptr ? std::atoi(ptr) : default_threshold;
425-
}();
426-
return value;
427-
}
428-
#endif
429-
430394
static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
431395
// if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
432396
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
@@ -441,26 +405,10 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
441405
return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
442406
} else if (mat2.dim() == 2 && mat2.dim() == 2) {
443407
// aten::addmm
444-
#if AT_MKLDNN_ACL_ENABLED()
445-
const int64_t mkldnn_acl_addmm_min_dim = get_mkldnn_acl_addmm_min_dim();
446-
const int64_t mkldnn_acl_addmm_min_size = get_mkldnn_acl_addmm_min_size();
447-
// M > MIN_DIM and N > MIN_DIM and K > MIN_DIM and M*N*K > MIN_SIZE
448-
return mat1.size(0) > mkldnn_acl_addmm_min_dim
449-
&& mat1.size(1) > mkldnn_acl_addmm_min_dim
450-
&& mat2.size(1) > mkldnn_acl_addmm_min_dim
451-
&& mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_acl_addmm_min_size;
452-
#else
453408
return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
454-
#endif
455409
} else {
456410
// aten::bmm, aten::baddbmm
457-
#if AT_MKLDNN_ACL_ENABLED()
458-
const int64_t mkldnn_acl_bmm_baddbmm_threshold = get_mkldnn_acl_bmm_baddbmm_threshold();
459-
// BATCH_SIZE^2 * M * N * K >= THRESHOLD
460-
return mat1.size(0) * mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) >= mkldnn_acl_bmm_baddbmm_threshold;
461-
#else
462411
return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
463-
#endif
464412
}
465413
}
466414

0 commit comments

Comments
 (0)