Skip to content

Commit c2c53b7

Browse files
committed
Bypassed the optimization pass that introduced casts to fp32; also added a faster implementation of fp16 MatMul.
1 parent d6d981e commit c2c53b7

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

onnxruntime/core/optimizer/insert_cast_transformer.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
221221
}
222222

223223
// Intentionally a no-op: the transform that unassigned isolated fp16 CPU nodes
// (forcing them to fp32 via inserted Casts) is bypassed, because fp16 nodes can
// now run natively on CPU (see the MLFloat16 MatMul fast path) and the extra
// Cast round-trips only added conversion overhead.
//
// @param graph                the graph that would have been transformed (unused)
// @param cpu_kernel_registry  registry used by the old isolation check (unused)
// @return Status::OK() always.
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
  // Suppress unused-parameter warnings now that the pass body is removed;
  // the signature is kept so callers are unaffected.
  static_cast<void>(graph);
  static_cast<void>(cpu_kernel_registry);
  return Status::OK();
}

onnxruntime/core/util/math_cpu.cc

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,58 @@ void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const
209209
MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool);
210210
}
211211

212+
212213
template <>
213-
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
214-
// Set alpha to 1 and beta to 0
215-
Eigen::half alpha = Eigen::half(1.0f);
216-
Eigen::half beta = Eigen::half(0.0f);
217-
218-
// Use GEMM with the given parameters
219-
math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, alpha,
220-
reinterpret_cast<const Eigen::half*>(A), reinterpret_cast<const Eigen::half*>(B), beta,
221-
reinterpret_cast<Eigen::half*>(C), thread_pool);
214+
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* a_data, const MLFloat16* b_data, MLFloat16* y_data, concurrency::ThreadPool* thread_pool) {
215+
216+
MLFloat16 alpha = MLFloat16(1.0f);
217+
MLFloat16 beta = MLFloat16(0.0f);
218+
// if input is empty tensor, return directly as nothing need to be calculated.
219+
if (M == 0 || N == 0)
220+
return;
221+
222+
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
223+
#pragma GCC diagnostic push
224+
#pragma GCC diagnostic ignored "-Wclass-memaccess"
225+
#endif
226+
227+
memset(&beta, 0, sizeof(MLFloat16));
228+
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
229+
#pragma GCC diagnostic pop
230+
#endif
231+
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
232+
233+
234+
MLAS_HALF_GEMM_DATA_PARAMS data;
235+
data.A = a_data;
236+
data.lda = K;
237+
data.B = b_data;
238+
data.ldb = N;
239+
data.C = y_data;
240+
data.ldc = N;
241+
// if (c_shape != nullptr) {
242+
// data.Bias = c_data;
243+
// }
244+
MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool);
245+
return;
246+
247+
#endif
248+
// Fallback to Eigen
249+
// // Broadcast the bias as needed if bias is given
250+
// GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
251+
#if defined(__GNUC__)
252+
#pragma GCC diagnostic push
253+
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
254+
#endif
255+
math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast<Eigen::half*>(&alpha),
256+
reinterpret_cast<const Eigen::half*>(a_data), reinterpret_cast<const Eigen::half*>(b_data), *reinterpret_cast<Eigen::half*>(&beta), reinterpret_cast<Eigen::half*>(y_data), thread_pool);
257+
#if defined(__GNUC__)
258+
#pragma GCC diagnostic pop
259+
#endif
260+
222261
}
223262

263+
224264
#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
225265
template <>
226266
void MatMul<double>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) {

0 commit comments

Comments
 (0)