@@ -1360,6 +1360,41 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
 #endif
 
 
+static inline int64_t get_mkldnn_matmul_min_dim() {
+  static auto value = [&] {
+    const int64_t default_min_dim = [&] {
+      // Minimum dimension requirement for MKLDNN; derived from experiments.
+      // It is enabled on all Neoverse CPUs.
+      return is_arm_neoverse() ? 8 : 0;
+    }();
+    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_DIM");
+    return value.has_value() ? std::stoi(value.value()) : default_min_dim;
+  }();
+  return value;
+}
+
+
+static inline int64_t get_mkldnn_matmul_min_size() {
+  static auto value = [&] {
+    const int64_t default_min_size = [&] {
+      // Minimum size requirement for MKLDNN; derived from experiments.
+      // It is enabled on all Neoverse CPUs.
+      return is_arm_neoverse() ? 8 * 1024 : 0;
+    }();
+    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_SIZE");
+    return value.has_value() ? std::stoi(value.value()) : default_min_size;
+  }();
+  return value;
+}
+
+
+static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
+  const int64_t min_dim = get_mkldnn_matmul_min_dim();
+  const int64_t min_size = get_mkldnn_matmul_min_size();
+  return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
+}
+
+
 static void addmm_impl_cpu_(
     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1479,7 +1514,8 @@ static void addmm_impl_cpu_(
   // that will then call into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running kernel with BF16 instructions
   if (transpose_c) {
-    if (use_mkldnn_matmul(b, a, c) && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
+    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
       try {
         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
         // We have dispatched to ACL GEMM for single precision float
@@ -1735,7 +1771,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
            (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
   };
 
-  if (use_mkldnn_matmul(batch1, batch2, self_or_result)) {
+  bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
+  if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
     try {
       mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
       return;
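
For readers who want to see the heuristic in isolation, below is a minimal standalone sketch, not part of the patch. It hard-codes the Neoverse defaults (min_dim = 8, min_size = 8 * 1024), uses std::getenv in place of c10::utils::get_env, and omits the at::globalContext().userEnabledMkldnn() check; the helper names env_or_default and would_use_mkldnn are illustrative only. At the addmm call site the arguments map to (m, k, n) = (b.sizes()[0], b.sizes()[1], a.sizes()[1]); at the baddbmm call site to (batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]).

    // Standalone sketch of the dispatch heuristic above, assuming Neoverse
    // defaults. std::getenv stands in for c10::utils::get_env and the
    // userEnabledMkldnn() check is omitted.
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>
    #include <string>

    static int64_t env_or_default(const char* name, int64_t fallback) {
      // The environment variable, if set, overrides the experiment-derived default.
      const char* raw = std::getenv(name);
      return raw != nullptr ? std::stoll(std::string(raw)) : fallback;
    }

    static bool would_use_mkldnn(int64_t m, int64_t k, int64_t n) {
      const int64_t min_dim  = env_or_default("TORCH_MKLDNN_MATMUL_MIN_DIM", 8);
      const int64_t min_size = env_or_default("TORCH_MKLDNN_MATMUL_MIN_SIZE", 8 * 1024);
      // Every dimension must exceed min_dim and the total work m*k*n must exceed min_size.
      return m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
    }

    int main() {
      // With the defaults, a 16x16x16 GEMM fails the size check (4096 <= 8192),
      // while a 32x32x32 GEMM passes both checks.
      std::cout << would_use_mkldnn(16, 16, 16) << '\n';  // 0
      std::cout << would_use_mkldnn(32, 32, 32) << '\n';  // 1
      return 0;
    }

Under these assumptions, small GEMMs stay on the existing BLAS path and only sufficiently large ones are routed through mkldnn_matmul; setting TORCH_MKLDNN_MATMUL_MIN_DIM=0 and TORCH_MKLDNN_MATMUL_MIN_SIZE=0 effectively disables the new thresholds, which also matches the defaults on non-Neoverse CPUs.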