@@ -51,18 +51,25 @@ EIGEN_MATMUL_FUNCTION(int64_t)
5151EIGEN_MATMUL_FUNCTION (uint64_t )
5252
5353
54- template <>
55- void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*) {
56- // Convert MLFloat16* to Eigen::half* using reinterpret_cast
57- const Eigen::half* A_half = reinterpret_cast <const Eigen::half*>(A);
58- const Eigen::half* B_half = reinterpret_cast <const Eigen::half*>(B);
59- Eigen::half* C_half = reinterpret_cast <Eigen::half*>(C);
60-
61- // Perform matrix multiplication using Eigen
62- auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(C_half, N, M);
63- C_mat.noalias () = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(B_half, N, K) *
64- Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(A_half, K, M);
65- }
54+
55+
56+ // template <>
57+ // void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
58+ // // // Convert MLFloat16* to Eigen::half* using reinterpret_cast
59+ // // const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
60+ // // const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
61+ // // Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);
62+
63+ // // // Perform matrix multiplication using Eigen
64+ // // auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(C_half, M, N);
65+ // // C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(A_half, M, K) *
66+ // // Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(B_half, K, N);
67+
68+ // // Optionally, handle threading with thread_pool if needed (not shown here).
69+
70+ // math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast<Eigen::half*>(&alpha),
71+ // reinterpret_cast<const Eigen::half*>(A), reinterpret_cast<const Eigen::half*>(B), *reinterpret_cast<Eigen::half*>(&beta), reinterpret_cast<Eigen::half*>(C), thread_pool);
72+ // }
6673
6774// template void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*);
6875
@@ -202,6 +209,18 @@ void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const
202209 MlasGemm (CblasNoTrans, CblasNoTrans, M, N, K, 1 .f , A, K, B, N, 0 .f , C, N, threadpool);
203210}
204211
212+ template <>
213+ void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
214+ // Set alpha to 1 and beta to 0
215+ Eigen::half alpha = Eigen::half (1 .0f );
216+ Eigen::half beta = Eigen::half (0 .0f );
217+
218+ // Use GEMM with the given parameters
219+ math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, alpha,
220+ reinterpret_cast <const Eigen::half*>(A), reinterpret_cast <const Eigen::half*>(B), beta,
221+ reinterpret_cast <Eigen::half*>(C), thread_pool);
222+ }
223+
205224#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
206225template <>
207226void MatMul<double >(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double * A, const double * B, double * C, ThreadPool* threadpool) {
0 commit comments