@@ -1,7 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Config.h>
 #include <ATen/Context.h>
-#include <ATen/Dispatch.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/mkldnn/Matmul.h>
 
@@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ |
   }
 }
 
-template <typename T>
-bool use_mkldnn_typed_matmul(
+bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  bool dtype_check = false;
-  if constexpr (std::is_same_v<T, c10::BFloat16>) {
 #if defined(__aarch64__)
-    if (mkldnn_bf16_device_check_arm()) {
-      // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
-      // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
-      // inputs, allow it for float as well
-      dtype_check = use_mkldnn_bf16_matmul() &&
-          ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
-    }
-#else
-    dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
-        (mat1.scalar_type() == kBFloat16);
+  if (mkldnn_bf16_device_check_arm()) {
+    // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
+    // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
+    // inputs, allow it for float as well
+    return (
+        use_mkldnn_bf16_matmul() &&
+        (mat1.scalar_type() == mat2.scalar_type()) &&
+        (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
+        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+  } else
 #endif
-  } else if constexpr (std::is_same_v<T, c10::Half>) {
-    dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
-        (mat1.scalar_type() == kHalf);
-  } else if constexpr (std::is_same_v<T, float>) {
-    dtype_check = dtype_check &&
-        (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
-        (mat1.scalar_type() == kFloat);
+  {
+    return (
+        use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
+        mat2.scalar_type() == kBFloat16 &&
+        (!result.defined() || result.scalar_type() == kBFloat16) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
   }
-  if (!dtype_check) {
-    return false;
-  }
-  bool size_check =
-      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
-  dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
-      (!result.defined() || result.scalar_type() == mat1.scalar_type());
-  return dtype_check && size_check;
+}
+
+bool use_mkldnn_fp16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
+      mat2.scalar_type() == kHalf &&
+      (!result.defined() || result.scalar_type() == kHalf) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_bf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_tf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
 }
 
 bool use_mkldnn_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  auto mat1_type = mat1.scalar_type();
-  if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
-    return false;
-  }
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
-        return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
-      });
-  return false;
+  return (
+      use_mkldnn_bf16_matmul(mat1, mat2, result) ||
+      use_mkldnn_fp16_matmul(mat1, mat2, result) ||
+      use_mkldnn_bf32_matmul(mat1, mat2, result) ||
+      use_mkldnn_tf32_matmul(mat1, mat2, result));
 }
 
 static void _mkldnn_matmul_i8i8i32_with_primitive(