From d4b80e6f8a7851f51cc85eaf6ffd6a1e1ba3ab9b Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 29 Oct 2024 18:05:26 +0000 Subject: [PATCH 1/4] registering float16 matmul. wip --- .../providers/cpu/cpu_execution_provider.cc | 30 ++++++++++++++++++- onnxruntime/core/providers/cpu/math/matmul.cc | 28 +++++++++++++++++ onnxruntime/core/util/math_cpu.cc | 17 +++++++++++ .../test/providers/cpu/math/matmul_test.cc | 2 +- 4 files changed, 75 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index acb7001501..3d4e8243e2 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -144,6 +144,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -344,6 +347,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -514,6 +520,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -620,6 +629,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -2814,7 +2826,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - + }; for (auto& function_table_entry : function_table) { @@ -2827,6 +2839,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is defined") +#else +#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is NOT defined") +#endif + + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -2853,6 +2872,14 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { MLFloat16, LeakyRelu)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { @@ -3104,6 +3131,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry)); #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED if (MlasFp16AccelerationSupported()) { + #pragma message("calling RegisterFp16Kernels") ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry)); } #endif diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de..8b906bc1a0 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -88,6 +88,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( .TypeConstraint("T", BuildKernelDefConstraints()), MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 7, 8, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 10, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 11, 12, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + MatMul, + 13, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + template Status MatMul::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 135b4bb4c7..3a2e924ea3 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -50,6 +50,23 @@ EIGEN_MATMUL_FUNCTION(uint32_t) EIGEN_MATMUL_FUNCTION(int64_t) EIGEN_MATMUL_FUNCTION(uint64_t) + +template <> +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) { + // Convert MLFloat16* to Eigen::half* using reinterpret_cast + const Eigen::half* A_half = reinterpret_cast(A); + const Eigen::half* B_half = reinterpret_cast(B); + Eigen::half* C_half = reinterpret_cast(C); + + // Perform matrix multiplication using Eigen + auto C_mat = Eigen::Map>(C_half, N, M); + C_mat.noalias() = Eigen::Map>(B_half, N, K) * + Eigen::Map>(A_half, K, M); +} + +// template void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*); + + //////////////////////////////////////////////////////////////////////////////// // BLAS alternatives. // Depending on whether we have specified an external BLAS library or not, we diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 298e870f34..2da909aada 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -311,7 +311,7 @@ TEST(MathOpTest, MatMul_Float16) { run_test(true); run_test(false); } -#endif +// #endif #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) TEST(MathOpTest, MatMul_bfloat16) { From ba450bcf3cbd264039b602abcb2934a0dfaa2a1a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Thu, 31 Oct 2024 00:08:29 +0000 Subject: [PATCH 2/4] fixed validation issue --- onnxruntime/core/util/math_cpu.cc | 43 ++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 3a2e924ea3..a200107bc1 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -51,18 +51,25 @@ EIGEN_MATMUL_FUNCTION(int64_t) EIGEN_MATMUL_FUNCTION(uint64_t) -template <> -void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) { - // Convert MLFloat16* to Eigen::half* using reinterpret_cast - const Eigen::half* A_half = reinterpret_cast(A); - const Eigen::half* B_half = reinterpret_cast(B); - Eigen::half* C_half = reinterpret_cast(C); - - // Perform matrix multiplication using Eigen - auto C_mat = Eigen::Map>(C_half, N, M); - C_mat.noalias() = Eigen::Map>(B_half, N, K) * - Eigen::Map>(A_half, K, M); -} + + +// template <> +// void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) { +// // // Convert MLFloat16* to Eigen::half* using reinterpret_cast +// // const Eigen::half* A_half = reinterpret_cast(A); +// // const Eigen::half* B_half = reinterpret_cast(B); +// // Eigen::half* C_half = reinterpret_cast(C); + +// // // Perform matrix multiplication using Eigen +// // auto C_mat = Eigen::Map>(C_half, M, N); +// // C_mat.noalias() = Eigen::Map>(A_half, M, K) * +// // Eigen::Map>(B_half, K, N); + +// // Optionally, handle threading with thread_pool if needed (not shown here). + +// math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast(&alpha), +// reinterpret_cast(A), reinterpret_cast(B), *reinterpret_cast(&beta), reinterpret_cast(C), thread_pool); +// } // template void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*); @@ -202,6 +209,18 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool); } +template <> +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) { + // Set alpha to 1 and beta to 0 + Eigen::half alpha = Eigen::half(1.0f); + Eigen::half beta = Eigen::half(0.0f); + + // Use GEMM with the given parameters + math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, alpha, + reinterpret_cast(A), reinterpret_cast(B), beta, + reinterpret_cast(C), thread_pool); +} + #ifdef MLAS_SUPPORTS_GEMM_DOUBLE template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) { From 5eb5d95b2f35bbf43abf6e9aed062904471096a7 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Wed, 6 Nov 2024 18:16:47 +0000 Subject: [PATCH 3/4] bypassed optimization pass which introduced casts to fp32 . also added faster implementation of fp16 MatMul --- .../core/optimizer/insert_cast_transformer.cc | 12 ++-- onnxruntime/core/util/math_cpu.cc | 58 ++++++++++++++++--- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab..9550906071 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -221,12 +221,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: } static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { - for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { - // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 - node.SetExecutionProviderType(""); - } - } + // for (auto& node : graph.Nodes()) { + // if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + // // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 + // node.SetExecutionProviderType(""); + // } + // } return Status::OK(); } diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index a200107bc1..48663a9684 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -209,18 +209,58 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool); } + template <> -void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) { - // Set alpha to 1 and beta to 0 - Eigen::half alpha = Eigen::half(1.0f); - Eigen::half beta = Eigen::half(0.0f); - - // Use GEMM with the given parameters - math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, alpha, - reinterpret_cast(A), reinterpret_cast(B), beta, - reinterpret_cast(C), thread_pool); +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* a_data, const MLFloat16* b_data, MLFloat16* y_data, concurrency::ThreadPool* thread_pool) { + +MLFloat16 alpha = MLFloat16(1.0f); +MLFloat16 beta = MLFloat16(0.0f); + // if input is empty tensor, return directly as nothing need to be calculated. + if (M == 0 || N == 0) + return; + +#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + +memset(&beta, 0, sizeof(MLFloat16)); +#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) +#pragma GCC diagnostic pop +#endif +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + + + MLAS_HALF_GEMM_DATA_PARAMS data; + data.A = a_data; + data.lda = K; + data.B = b_data; + data.ldb = N; + data.C = y_data; + data.ldc = N; + // if (c_shape != nullptr) { + // data.Bias = c_data; + // } + MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool); + return; + +#endif + // Fallback to Eigen + // // Broadcast the bias as needed if bias is given + // GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data); +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast(&alpha), + reinterpret_cast(a_data), reinterpret_cast(b_data), *reinterpret_cast(&beta), reinterpret_cast(y_data), thread_pool); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + } + #ifdef MLAS_SUPPORTS_GEMM_DOUBLE template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) { From 5a40b3e0948453fc45e6b9117a1d268cbddd28ad Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 12 May 2025 21:48:57 +0000 Subject: [PATCH 4/4] fixed issues with matmul fp16 --- onnxruntime/core/providers/cpu/math/matmul.cc | 7 ++++++- onnxruntime/test/providers/cpu/math/matmul_test.cc | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8b906bc1a0..f78e12523e 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -136,7 +136,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { // be filled out with zeros. EigenMatrixMapRowMajor dest(y->MutableData(), narrow(helper.M()), narrow(helper.N())); - dest.setZero(); + if constexpr (std::is_same::value) { + dest.setConstant(MLFloat16(0.0f)); + } else { + dest.setZero(); + } + return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 2da909aada..298e870f34 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -311,7 +311,7 @@ TEST(MathOpTest, MatMul_Float16) { run_test(true); run_test(false); } -// #endif +#endif #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) TEST(MathOpTest, MatMul_bfloat16) {