Skip to content

Commit d4b80e6

Browse files
committed
registering float16 matmul. wip
1 parent 44d9f6c commit d4b80e6

File tree

4 files changed

+75
-2
lines changed

4 files changed

+75
-2
lines changed

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco
144144
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan);
145145
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
146146
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
147+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
148+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, MatMul);
149+
#endif
147150
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
148151
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
149152
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
@@ -344,6 +347,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
344347
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
345348
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
346349
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
350+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
351+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, MatMul);
352+
#endif
347353
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
348354
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
349355
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
@@ -514,6 +520,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp
514520
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
515521
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
516522
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
523+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
524+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, MatMul);
525+
#endif
517526
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
518527
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
519528
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
@@ -620,6 +629,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
620629
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
621630
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
622631
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
632+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
633+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
634+
#endif
623635
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
624636
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
625637
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
@@ -2814,7 +2826,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
28142826

28152827

28162828
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuadricCustomOp)>,
2817-
2829+
28182830
};
28192831

28202832
for (auto& function_table_entry : function_table) {
@@ -2827,6 +2839,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
28272839
return Status::OK();
28282840
}
28292841

2842+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
2843+
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is defined")
2844+
#else
2845+
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is NOT defined")
2846+
#endif
2847+
2848+
28302849
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
28312850
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
28322851
static const BuildKernelCreateInfoFn function_table[] = {
@@ -2853,6 +2872,14 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
28532872
MLFloat16, LeakyRelu)>,
28542873
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, MLFloat16,
28552874
LeakyRelu)>,
2875+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
2876+
MLFloat16, MatMul)>,
2877+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
2878+
MLFloat16, MatMul)>,
2879+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
2880+
MLFloat16, MatMul)>,
2881+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
2882+
MatMul)>
28562883
};
28572884

28582885
for (auto& function_table_entry : function_table) {
@@ -3104,6 +3131,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
31043131
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
31053132
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
31063133
if (MlasFp16AccelerationSupported()) {
3134+
#pragma message("calling RegisterFp16Kernels")
31073135
ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry));
31083136
}
31093137
#endif

onnxruntime/core/providers/cpu/math/matmul.cc

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
8888
.TypeConstraint("T", BuildKernelDefConstraints<int64_t, uint64_t>()),
8989
MatMul<int64_t>);
9090

91+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
92+
MatMul,
93+
7, 8,
94+
MLFloat16,
95+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
96+
MatMul<MLFloat16>);
97+
98+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
99+
MatMul,
100+
9, 10,
101+
MLFloat16,
102+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
103+
MatMul<MLFloat16>);
104+
105+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
106+
MatMul,
107+
11, 12,
108+
MLFloat16,
109+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
110+
MatMul<MLFloat16>);
111+
112+
ONNX_CPU_OPERATOR_TYPED_KERNEL(
113+
MatMul,
114+
13,
115+
MLFloat16,
116+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
117+
MatMul<MLFloat16>);
118+
91119
template <typename T>
92120
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
93121
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

onnxruntime/core/util/math_cpu.cc

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,23 @@ EIGEN_MATMUL_FUNCTION(uint32_t)
5050
EIGEN_MATMUL_FUNCTION(int64_t)
5151
EIGEN_MATMUL_FUNCTION(uint64_t)
5252

53+
54+
template <>
55+
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) {
56+
// Convert MLFloat16* to Eigen::half* using reinterpret_cast
57+
const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
58+
const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
59+
Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);
60+
61+
// Perform matrix multiplication using Eigen
62+
auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(C_half, N, M);
63+
C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(B_half, N, K) *
64+
Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(A_half, K, M);
65+
}
66+
67+
// template void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*);
68+
69+
5370
////////////////////////////////////////////////////////////////////////////////
5471
// BLAS alternatives.
5572
// Depending on whether we have specified an external BLAS library or not, we

onnxruntime/test/providers/cpu/math/matmul_test.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,7 @@ TEST(MathOpTest, MatMul_Float16) {
311311
run_test(true);
312312
run_test(false);
313313
}
314-
#endif
314+
// #endif
315315

316316
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL)
317317
TEST(MathOpTest, MatMul_bfloat16) {

0 commit comments

Comments (0)