Skip to content

Commit 8a33ae0

Browse files
committed
registering float16 matmul. wip
1 parent 200f991 commit 8a33ae0

File tree

4 files changed

+67
-3
lines changed

4 files changed

+67
-3
lines changed

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
146146
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
147147
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
148148
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
149+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, MatMul);
149150
#endif
150151
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
151152
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
@@ -349,6 +350,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
349350
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
350351
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
351352
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
353+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, MatMul);
352354
#endif
353355
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
354356
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
@@ -522,6 +524,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
522524
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
523525
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
524526
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm);
527+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, MatMul);
525528
#endif
526529
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
527530
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
@@ -631,6 +634,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
631634
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
632635
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
633636
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm);
637+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
634638
#endif
635639
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
636640
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
@@ -2839,6 +2843,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
28392843
return Status::OK();
28402844
}
28412845

2846+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
2847+
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is defined")
2848+
#else
2849+
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is NOT defined")
2850+
#endif
2851+
2852+
28422853
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
28432854
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
28442855
static const BuildKernelCreateInfoFn function_table[] = {
@@ -2870,10 +2881,17 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
28702881
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
28712882
MLFloat16, Gemm)>,
28722883
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
2873-
MLFloat16, Gemm)>,
2874-
2884+
MLFloat16, Gemm)>,
28752885
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
28762886
Gemm)>,
2887+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
2888+
MLFloat16, MatMul)>,
2889+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
2890+
MLFloat16, MatMul)>,
2891+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
2892+
MLFloat16, MatMul)>,
2893+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
2894+
MatMul)>
28772895
};
28782896

28792897
for (auto& function_table_entry : function_table) {
@@ -3125,6 +3143,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
31253143
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
31263144
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
31273145
if (MlasFp16AccelerationSupported()) {
3146+
#pragma message("calling RegisterFp16Kernels")
31283147
ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry));
31293148
}
31303149
#endif

onnxruntime/core/providers/cpu/math/matmul.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
8888
.TypeConstraint("T", BuildKernelDefConstraints<int64_t, uint64_t>()),
8989
MatMul<int64_t>);
9090

91+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
92+
MatMul,
93+
7, 8,
94+
MLFloat16,
95+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
96+
MatMul<MLFloat16>);
97+
98+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
99+
MatMul,
100+
9, 10,
101+
MLFloat16,
102+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
103+
MatMul<MLFloat16>);
104+
105+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
106+
MatMul,
107+
11, 12,
108+
MLFloat16,
109+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
110+
MatMul<MLFloat16>);
111+
112+
ONNX_CPU_OPERATOR_TYPED_KERNEL(
113+
MatMul,
114+
13,
115+
MLFloat16,
116+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
117+
MatMul<MLFloat16>);
118+
91119
template <typename T>
92120
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
93121
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

onnxruntime/core/util/math_cpu.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,23 @@ EIGEN_MATMUL_FUNCTION(uint32_t)
5050
EIGEN_MATMUL_FUNCTION(int64_t)
5151
EIGEN_MATMUL_FUNCTION(uint64_t)
5252

53+
54+
template <>
55+
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) {
56+
// Convert MLFloat16* to Eigen::half* using reinterpret_cast
57+
const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
58+
const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
59+
Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);
60+
61+
// Perform matrix multiplication using Eigen
62+
auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(C_half, N, M);
63+
C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(B_half, N, K) *
64+
Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(A_half, K, M);
65+
}
66+
67+
// template void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*);
68+
69+
5370
////////////////////////////////////////////////////////////////////////////////
5471
// BLAS alternatives.
5572
// Depending on whether we have specified an external BLAS library or not, we

onnxruntime/test/providers/cpu/math/matmul_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ TEST(MathOpTest, MatMul_Float16) {
311311
run_test(true);
312312
run_test(false);
313313
}
314-
#endif
314+
// #endif
315315

316316
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL)
317317
TEST(MathOpTest, MatMul_bfloat16) {

0 commit comments

Comments
 (0)