diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab..9550906071 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -221,12 +221,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: } static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { - for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { - // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 - node.SetExecutionProviderType(""); - } - } + // for (auto& node : graph.Nodes()) { + // if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + // // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 + // node.SetExecutionProviderType(""); + // } + // } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index acb7001501..3d4e8243e2 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -144,6 +144,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -344,6 +347,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -514,6 +520,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -620,6 +629,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -2814,7 +2826,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - + }; for (auto& function_table_entry : function_table) { @@ -2827,6 +2839,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is defined") +#else +#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is NOT defined") +#endif + + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -2853,6 +2872,14 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { MLFloat16, LeakyRelu)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { @@ -3104,6 +3131,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry)); #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED if (MlasFp16AccelerationSupported()) { + #pragma message("calling RegisterFp16Kernels") ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry)); } #endif diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de..f78e12523e 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -88,6 +88,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( .TypeConstraint("T", BuildKernelDefConstraints()), MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 7, 8, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 10, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 11, 12, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( + MatMul, + 13, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + template Status MatMul::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); @@ -108,7 +136,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { // be filled out with zeros. EigenMatrixMapRowMajor dest(y->MutableData(), narrow(helper.M()), narrow(helper.N())); - dest.setZero(); + if constexpr (std::is_same::value) { + dest.setConstant(MLFloat16(0.0f)); + } else { + dest.setZero(); + } + return Status::OK(); } diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 135b4bb4c7..48663a9684 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -50,6 +50,30 @@ EIGEN_MATMUL_FUNCTION(uint32_t) EIGEN_MATMUL_FUNCTION(int64_t) EIGEN_MATMUL_FUNCTION(uint64_t) + + + +// template <> +// void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) { +// // // Convert MLFloat16* to Eigen::half* using reinterpret_cast +// // const Eigen::half* A_half = reinterpret_cast(A); +// // const Eigen::half* B_half = reinterpret_cast(B); +// // Eigen::half* C_half = reinterpret_cast(C); + +// // // Perform matrix multiplication using Eigen +// // auto C_mat = Eigen::Map>(C_half, M, N); +// // C_mat.noalias() = Eigen::Map>(A_half, M, K) * +// // Eigen::Map>(B_half, K, N); + +// // Optionally, handle threading with thread_pool if needed (not shown here). + +// math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast(&alpha), +// reinterpret_cast(A), reinterpret_cast(B), *reinterpret_cast(&beta), reinterpret_cast(C), thread_pool); +// } + +// template void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*); + + //////////////////////////////////////////////////////////////////////////////// // BLAS alternatives. // Depending on whether we have specified an external BLAS library or not, we @@ -185,6 +209,58 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool); } + +template <> +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* a_data, const MLFloat16* b_data, MLFloat16* y_data, concurrency::ThreadPool* thread_pool) { + +MLFloat16 alpha = MLFloat16(1.0f); +MLFloat16 beta = MLFloat16(0.0f); + // if input is empty tensor, return directly as nothing need to be calculated. + if (M == 0 || N == 0) + return; + +#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + +memset(&beta, 0, sizeof(MLFloat16)); +#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS) +#pragma GCC diagnostic pop +#endif +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + + + MLAS_HALF_GEMM_DATA_PARAMS data; + data.A = a_data; + data.lda = K; + data.B = b_data; + data.ldb = N; + data.C = y_data; + data.ldc = N; + // if (c_shape != nullptr) { + // data.Bias = c_data; + // } + MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool); + return; + +#endif + // Fallback to Eigen + // // Broadcast the bias as needed if bias is given + // GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data); +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + math::Gemm(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast(&alpha), + reinterpret_cast(a_data), reinterpret_cast(b_data), *reinterpret_cast(&beta), reinterpret_cast(y_data), thread_pool); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +} + + #ifdef MLAS_SUPPORTS_GEMM_DOUBLE template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) {