@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282 return device_type.str ();
8383}
8484
85+ template <typename Ts> struct matrix_info_t {
86+ oneapi::mkl::transpose transpose_info[2 ];
87+ Ts value_info[2 ];
88+ std::int64_t size_info[3 ];
89+ std::int64_t ld_info[3 ];
90+ std::int64_t groupsize_info;
91+ };
92+
8593namespace dpct
8694{
8795 typedef sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
17271735 };
17281736
17291737 template <class Ta , class Tb , class Tc , class Ts >
1730- inline void gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1731- oneapi::mkl::transpose b_trans, int m, int n, int k,
1732- const void *alpha, const void **a, int lda,
1733- const void **b, int ldb, const void *beta, void **c,
1734- int ldc, int batch_size)
1735- {
1736- struct matrix_info_t
1737- {
1738- oneapi::mkl::transpose transpose_info[2 ];
1739- Ts value_info[2 ];
1740- std::int64_t size_info[3 ];
1741- std::int64_t ld_info[3 ];
1742- std::int64_t groupsize_info;
1743- };
1744-
1738+ inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741+ matrix_info_t <float > * matrix_info) {
17451742 Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
17461743 Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
17471744
1748- matrix_info_t *matrix_info =
1749- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
17501745 matrix_info->transpose_info [0 ] = a_trans;
17511746 matrix_info->transpose_info [1 ] = b_trans;
17521747 matrix_info->value_info [0 ] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
17631758 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17641759 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
17651760 matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766- matrix_info->size_info + 2 , matrix_info-> value_info , reinterpret_cast <const Ta **>(a ),
1767- matrix_info-> ld_info , reinterpret_cast <const Tb **>(b ), matrix_info->ld_info + 1 ,
1768- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769- &(matrix_info->groupsize_info ));
1761+ matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info-> value_info ),
1762+ reinterpret_cast <const Ta **>(a ), matrix_info->ld_info , reinterpret_cast < const Tb **>(b) ,
1763+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ) ,
1764+ reinterpret_cast <Tc **>(c), matrix_info-> ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17701765#else
17711766 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17721767 q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1773- matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1768+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>( matrix_info->value_info ) ,
17741769 reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1775- matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c ),
1776- matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1770+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ),
1771+ reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17771772#endif
1778-
1779- q.submit ([&](sycl::handler &cgh)
1780- {
1781- cgh.depends_on (e);
1782- cgh.host_task ([=] { std::free (matrix_info); }); });
17831773 }
17841774
17851775 template <class Ta , class Tb , class Tc , class Ts >
@@ -2422,25 +2412,11 @@ namespace dpct
24222412 // / \param [in] ldc Leading dimension of C.
24232413 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24242414 // / \param [in] scaling_type Data type of the scaling factors.
2425- inline void gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2426- oneapi::mkl::transpose b_trans, int m, int n, int k,
2427- const void *alpha, const void *a[],
2428- library_data_t a_type, int lda, const void *b[],
2429- library_data_t b_type, int ldb, const void *beta,
2430- void *c[], library_data_t c_type, int ldc,
2431- int batch_size, library_data_t scaling_type)
2432- {
2433- if (scaling_type == library_data_t ::real_float &&
2434- c_type == library_data_t ::complex_float)
2435- {
2436- scaling_type = library_data_t ::complex_float;
2437- }
2438- else if (scaling_type == library_data_t ::real_double &&
2439- c_type == library_data_t ::complex_double)
2440- {
2441- scaling_type = library_data_t ::complex_double;
2442- }
2443-
2415+ inline void gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419+ matrix_info_t <float > * matrix_info) {
24442420 std::uint64_t key =
24452421 detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
24462422 switch (key)
@@ -2449,68 +2425,41 @@ namespace dpct
24492425 library_data_t ::real_float, library_data_t ::real_float,
24502426 library_data_t ::real_float, library_data_t ::real_float):
24512427 {
2452- detail::gemm_batch_impl<float , float , float , float >(
2453- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2454- batch_size);
2428+ detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429+ beta, c, ldc, batch_size, matrix_info);
24552430 break ;
24562431 }
24572432 case detail::get_type_combination_id (
24582433 library_data_t ::real_double, library_data_t ::real_double,
24592434 library_data_t ::real_double, library_data_t ::real_double):
24602435 {
2461- detail::gemm_batch_impl<double , double , double , double >(
2462- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2463- batch_size);
2464- break ;
2465- }
2466- case detail::get_type_combination_id (
2467- library_data_t ::complex_float, library_data_t ::complex_float,
2468- library_data_t ::complex_float, library_data_t ::complex_float):
2469- {
2470- detail::gemm_batch_impl<std::complex <float >, std::complex <float >,
2471- std::complex <float >, std::complex <float >>(
2472- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2473- batch_size);
2474- break ;
2475- }
2476- case detail::get_type_combination_id (
2477- library_data_t ::complex_double, library_data_t ::complex_double,
2478- library_data_t ::complex_double, library_data_t ::complex_double):
2479- {
2480- detail::gemm_batch_impl<std::complex <double >, std::complex <double >,
2481- std::complex <double >, std::complex <double >>(
2482- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2483- batch_size);
2436+ detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437+ beta, c, ldc, batch_size, matrix_info);
24842438 break ;
24852439 }
24862440 case detail::get_type_combination_id (
24872441 library_data_t ::real_half, library_data_t ::real_half,
24882442 library_data_t ::real_half, library_data_t ::real_half):
24892443 {
2490- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2491- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2492- a, lda, b, ldb, beta, c, ldc,
2493- batch_size);
2444+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24942446 break ;
24952447 }
24962448#ifdef __INTEL_MKL__
24972449 case detail::get_type_combination_id (
24982450 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24992451 library_data_t ::real_bfloat16, library_data_t ::real_float):
25002452 {
2501- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2502- oneapi::mkl::bfloat16, float >(
2503- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2504- batch_size);
2453+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2454+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25052455 break ;
25062456 }
25072457 case detail::get_type_combination_id (
25082458 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
25092459 library_data_t ::real_float, library_data_t ::real_float):
25102460 {
2511- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2512- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2513- b, ldb, beta, c, ldc, batch_size);
2461+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2462+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25142463 break ;
25152464 }
25162465#endif
@@ -2522,28 +2471,25 @@ namespace dpct
25222471 dpct::get_value (reinterpret_cast <const std::int32_t *>(alpha), q);
25232472 float beta_float =
25242473 dpct::get_value (reinterpret_cast <const std::int32_t *>(beta), q);
2525- detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2526- float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2527- a, lda, b, ldb, &beta_float, c, ldc,
2528- batch_size);
2474+ detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2475+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476+ matrix_info);
25292477 break ;
25302478 }
25312479 case detail::get_type_combination_id (
25322480 library_data_t ::real_int8, library_data_t ::real_int8,
25332481 library_data_t ::real_float, library_data_t ::real_float):
25342482 {
25352483 detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2536- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2537- batch_size);
2484+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25382485 break ;
25392486 }
25402487 case detail::get_type_combination_id (
25412488 library_data_t ::real_half, library_data_t ::real_half,
25422489 library_data_t ::real_float, library_data_t ::real_float):
25432490 {
25442491 detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2545- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2546- batch_size);
2492+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25472493 break ;
25482494 }
25492495 case detail::get_type_combination_id (
@@ -2557,8 +2503,7 @@ namespace dpct
25572503 sycl::half alpha_half (alpha_value);
25582504 sycl::half beta_half (beta_value);
25592505 detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2560- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2561- batch_size);
2506+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25622507 break ;
25632508 }
25642509 default :
0 commit comments