@@ -209,18 +209,58 @@ void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const
209209 MlasGemm (CblasNoTrans, CblasNoTrans, M, N, K, 1 .f , A, K, B, N, 0 .f , C, N, threadpool);
210210}
211211
212+
212213template <>
213- void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
214- // Set alpha to 1 and beta to 0
215- Eigen::half alpha = Eigen::half (1 .0f );
216- Eigen::half beta = Eigen::half (0 .0f );
217-
218- // Use GEMM with the given parameters
219- math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, alpha,
220- reinterpret_cast <const Eigen::half*>(A), reinterpret_cast <const Eigen::half*>(B), beta,
221- reinterpret_cast <Eigen::half*>(C), thread_pool);
214+ void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* a_data, const MLFloat16* b_data, MLFloat16* y_data, concurrency::ThreadPool* thread_pool) {
215+
216+ MLFloat16 alpha = MLFloat16 (1 .0f );
217+ MLFloat16 beta = MLFloat16 (0 .0f );
218+ // if input is empty tensor, return directly as nothing need to be calculated.
219+ if (M == 0 || N == 0 )
220+ return ;
221+
222+ #if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
223+ #pragma GCC diagnostic push
224+ #pragma GCC diagnostic ignored "-Wclass-memaccess"
225+ #endif
226+
227+ memset (&beta, 0 , sizeof (MLFloat16));
228+ #if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
229+ #pragma GCC diagnostic pop
230+ #endif
231+ #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
232+
233+
234+ MLAS_HALF_GEMM_DATA_PARAMS data;
235+ data.A = a_data;
236+ data.lda = K;
237+ data.B = b_data;
238+ data.ldb = N;
239+ data.C = y_data;
240+ data.ldc = N;
241+ // if (c_shape != nullptr) {
242+ // data.Bias = c_data;
243+ // }
244+ MlasHalfGemmBatch (M, N, K, 1 , &data, thread_pool);
245+ return ;
246+
247+ #endif
248+ // Fallback to Eigen
249+ // // Broadcast the bias as needed if bias is given
250+ // GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
251+ #if defined(__GNUC__)
252+ #pragma GCC diagnostic push
253+ #pragma GCC diagnostic ignored "-Wstrict-aliasing"
254+ #endif
255+ math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast <Eigen::half*>(&alpha),
256+ reinterpret_cast <const Eigen::half*>(a_data), reinterpret_cast <const Eigen::half*>(b_data), *reinterpret_cast <Eigen::half*>(&beta), reinterpret_cast <Eigen::half*>(y_data), thread_pool);
257+ #if defined(__GNUC__)
258+ #pragma GCC diagnostic pop
259+ #endif
260+
222261}
223262
263+
224264#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
225265template <>
226266void MatMul<double >(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double * A, const double * B, double * C, ThreadPool* threadpool) {
0 commit comments