@@ -377,19 +377,21 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran
377377}
378378
379379static inline void gemm (cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m,
380- std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer<half, 1 > &a,
381- std::int64_t lda, cl::sycl::buffer<half, 1 > &b, std::int64_t ldb, half beta,
382- cl::sycl::buffer<half, 1 > &c, std::int64_t ldc) {
380+ std::int64_t n, std::int64_t k, sycl::half alpha,
381+ cl::sycl::buffer<sycl::half, 1 > &a, std::int64_t lda,
382+ cl::sycl::buffer<sycl::half, 1 > &b, std::int64_t ldb, sycl::half beta,
383+ cl::sycl::buffer<sycl::half, 1 > &c, std::int64_t ldc) {
383384 gemm_precondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
384385 detail::gemm (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
385386 c, ldc);
386387 gemm_postcondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
387388}
388389
389390static inline void gemm (cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m,
390- std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer<half, 1 > &a,
391- std::int64_t lda, cl::sycl::buffer<half, 1 > &b, std::int64_t ldb,
392- float beta, cl::sycl::buffer<float , 1 > &c, std::int64_t ldc) {
391+ std::int64_t n, std::int64_t k, float alpha,
392+ cl::sycl::buffer<sycl::half, 1 > &a, std::int64_t lda,
393+ cl::sycl::buffer<sycl::half, 1 > &b, std::int64_t ldb, float beta,
394+ cl::sycl::buffer<float , 1 > &c, std::int64_t ldc) {
393395 gemm_precondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
394396 detail::gemm (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
395397 c, ldc);
@@ -470,10 +472,11 @@ static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpos
470472}
471473
472474static inline void gemm_batch (cl::sycl::queue &queue, transpose transa, transpose transb,
473- std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
474- cl::sycl::buffer<half, 1 > &a, std::int64_t lda, std::int64_t stride_a,
475- cl::sycl::buffer<half, 1 > &b, std::int64_t ldb, std::int64_t stride_b,
476- half beta, cl::sycl::buffer<half, 1 > &c, std::int64_t ldc,
475+ std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
476+ cl::sycl::buffer<sycl::half, 1 > &a, std::int64_t lda,
477+ std::int64_t stride_a, cl::sycl::buffer<sycl::half, 1 > &b,
478+ std::int64_t ldb, std::int64_t stride_b, sycl::half beta,
479+ cl::sycl::buffer<sycl::half, 1 > &c, std::int64_t ldc,
477480 std::int64_t stride_c, std::int64_t batch_size) {
478481 gemm_batch_precondition (queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
479482 stride_b, beta, c, ldc, stride_c, batch_size);
@@ -2509,9 +2512,10 @@ static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, tra
25092512}
25102513
25112514static inline cl::sycl::event gemm (cl::sycl::queue &queue, transpose transa, transpose transb,
2512- std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
2513- const half *a, std::int64_t lda, const half *b, std::int64_t ldb,
2514- half beta, half *c, std::int64_t ldc,
2515+ std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
2516+ const sycl::half *a, std::int64_t lda, const sycl::half *b,
2517+ std::int64_t ldb, sycl::half beta, sycl::half *c,
2518+ std::int64_t ldc,
25152519 const std::vector<cl::sycl::event> &dependencies = {}) {
25162520 gemm_precondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
25172521 dependencies);
@@ -2524,8 +2528,8 @@ static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, tra
25242528
25252529static inline cl::sycl::event gemm (cl::sycl::queue &queue, transpose transa, transpose transb,
25262530 std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2527- const half *a, std::int64_t lda, const half *b, std:: int64_t ldb ,
2528- float beta, float *c, std::int64_t ldc,
2531+ const sycl:: half *a, std::int64_t lda, const sycl:: half *b,
2532+ std:: int64_t ldb, float beta, float *c, std::int64_t ldc,
25292533 const std::vector<cl::sycl::event> &dependencies = {}) {
25302534 gemm_precondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
25312535 dependencies);
@@ -2618,9 +2622,9 @@ static inline cl::sycl::event gemm_batch(
26182622
26192623static inline cl::sycl::event gemm_batch (cl::sycl::queue &queue, transpose *transa,
26202624 transpose *transb, std::int64_t *m, std::int64_t *n,
2621- std::int64_t *k, half *alpha, const half **a,
2622- std::int64_t *lda, const half **b, std::int64_t *ldb,
2623- half *beta, half **c, std::int64_t *ldc,
2625+ std::int64_t *k, sycl:: half *alpha, const sycl:: half **a,
2626+ std::int64_t *lda, const sycl:: half **b, std::int64_t *ldb,
2627+ sycl:: half *beta, sycl:: half **c, std::int64_t *ldc,
26242628 std::int64_t group_count, std::int64_t *group_size,
26252629 const std::vector<cl::sycl::event> &dependencies = {}) {
26262630 gemm_batch_precondition (queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
@@ -2702,11 +2706,12 @@ static inline cl::sycl::event gemm_batch(
27022706}
27032707
27042708static inline cl::sycl::event gemm_batch (cl::sycl::queue &queue, transpose transa, transpose transb,
2705- std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
2706- const half *a, std::int64_t lda, std::int64_t stride_a,
2707- const half *b, std::int64_t ldb, std::int64_t stride_b,
2708- half beta, half *c, std::int64_t ldc,
2709- std::int64_t stride_c, std::int64_t batch_size,
2709+ std::int64_t m, std::int64_t n, std::int64_t k,
2710+ sycl::half alpha, const sycl::half *a, std::int64_t lda,
2711+ std::int64_t stride_a, const sycl::half *b,
2712+ std::int64_t ldb, std::int64_t stride_b, sycl::half beta,
2713+ sycl::half *c, std::int64_t ldc, std::int64_t stride_c,
2714+ std::int64_t batch_size,
27102715 const std::vector<cl::sycl::event> &dependencies = {}) {
27112716 gemm_batch_precondition (queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
27122717 stride_b, beta, c, ldc, stride_c, batch_size, dependencies);
0 commit comments