Skip to content

Commit d6d981e

Browse files
committed
fixed validation issue
1 parent 8a33ae0 commit d6d981e

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

onnxruntime/core/util/math_cpu.cc

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,25 @@ EIGEN_MATMUL_FUNCTION(int64_t)
5151
EIGEN_MATMUL_FUNCTION(uint64_t)
5252

5353

54-
template <>
55-
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) {
56-
// Convert MLFloat16* to Eigen::half* using reinterpret_cast
57-
const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
58-
const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
59-
Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);
60-
61-
// Perform matrix multiplication using Eigen
62-
auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(C_half, N, M);
63-
C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(B_half, N, K) *
64-
Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(A_half, K, M);
65-
}
54+
55+
56+
// template <>
57+
// void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
58+
// // // Convert MLFloat16* to Eigen::half* using reinterpret_cast
59+
// // const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
60+
// // const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
61+
// // Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);
62+
63+
// // // Perform matrix multiplication using Eigen
64+
// // auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(C_half, M, N);
65+
// // C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(A_half, M, K) *
66+
// // Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(B_half, K, N);
67+
68+
// // Optionally, handle threading with thread_pool if needed (not shown here).
69+
70+
// math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast<Eigen::half*>(&alpha),
71+
// reinterpret_cast<const Eigen::half*>(A), reinterpret_cast<const Eigen::half*>(B), *reinterpret_cast<Eigen::half*>(&beta), reinterpret_cast<Eigen::half*>(C), thread_pool);
72+
// }
6673

6774
// template void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*);
6875

@@ -202,6 +209,18 @@ void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const
202209
MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool);
203210
}
204211

212+
template <>
213+
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
214+
// Set alpha to 1 and beta to 0
215+
Eigen::half alpha = Eigen::half(1.0f);
216+
Eigen::half beta = Eigen::half(0.0f);
217+
218+
// Use GEMM with the given parameters
219+
math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, alpha,
220+
reinterpret_cast<const Eigen::half*>(A), reinterpret_cast<const Eigen::half*>(B), beta,
221+
reinterpret_cast<Eigen::half*>(C), thread_pool);
222+
}
223+
205224
#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
206225
template <>
207226
void MatMul<double>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) {

0 commit comments

Comments
 (0)