Skip to content

Commit 1c8c47a

Browse files
Replace half with sycl::half to align with SYCL standard (#143)
* [BLAS] replace half with sycl::half
1 parent 69af853 commit 1c8c47a

32 files changed

+548
-503
lines changed

include/oneapi/mkl/blas.hxx

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -377,19 +377,21 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran
377377
}
378378

379379
static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m,
380-
std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer<half, 1> &a,
381-
std::int64_t lda, cl::sycl::buffer<half, 1> &b, std::int64_t ldb, half beta,
382-
cl::sycl::buffer<half, 1> &c, std::int64_t ldc) {
380+
std::int64_t n, std::int64_t k, sycl::half alpha,
381+
cl::sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
382+
cl::sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, sycl::half beta,
383+
cl::sycl::buffer<sycl::half, 1> &c, std::int64_t ldc) {
383384
gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
384385
detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
385386
c, ldc);
386387
gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
387388
}
388389

389390
static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m,
390-
std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer<half, 1> &a,
391-
std::int64_t lda, cl::sycl::buffer<half, 1> &b, std::int64_t ldb,
392-
float beta, cl::sycl::buffer<float, 1> &c, std::int64_t ldc) {
391+
std::int64_t n, std::int64_t k, float alpha,
392+
cl::sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
393+
cl::sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, float beta,
394+
cl::sycl::buffer<float, 1> &c, std::int64_t ldc) {
393395
gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
394396
detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
395397
c, ldc);
@@ -470,10 +472,11 @@ static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpos
470472
}
471473

472474
static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb,
473-
std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
474-
cl::sycl::buffer<half, 1> &a, std::int64_t lda, std::int64_t stride_a,
475-
cl::sycl::buffer<half, 1> &b, std::int64_t ldb, std::int64_t stride_b,
476-
half beta, cl::sycl::buffer<half, 1> &c, std::int64_t ldc,
475+
std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
476+
cl::sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
477+
std::int64_t stride_a, cl::sycl::buffer<sycl::half, 1> &b,
478+
std::int64_t ldb, std::int64_t stride_b, sycl::half beta,
479+
cl::sycl::buffer<sycl::half, 1> &c, std::int64_t ldc,
477480
std::int64_t stride_c, std::int64_t batch_size) {
478481
gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
479482
stride_b, beta, c, ldc, stride_c, batch_size);
@@ -2509,9 +2512,10 @@ static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, tra
25092512
}
25102513

25112514
static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb,
2512-
std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
2513-
const half *a, std::int64_t lda, const half *b, std::int64_t ldb,
2514-
half beta, half *c, std::int64_t ldc,
2515+
std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
2516+
const sycl::half *a, std::int64_t lda, const sycl::half *b,
2517+
std::int64_t ldb, sycl::half beta, sycl::half *c,
2518+
std::int64_t ldc,
25152519
const std::vector<cl::sycl::event> &dependencies = {}) {
25162520
gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
25172521
dependencies);
@@ -2524,8 +2528,8 @@ static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, tra
25242528

25252529
static inline cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb,
25262530
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2527-
const half *a, std::int64_t lda, const half *b, std::int64_t ldb,
2528-
float beta, float *c, std::int64_t ldc,
2531+
const sycl::half *a, std::int64_t lda, const sycl::half *b,
2532+
std::int64_t ldb, float beta, float *c, std::int64_t ldc,
25292533
const std::vector<cl::sycl::event> &dependencies = {}) {
25302534
gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
25312535
dependencies);
@@ -2618,9 +2622,9 @@ static inline cl::sycl::event gemm_batch(
26182622

26192623
static inline cl::sycl::event gemm_batch(cl::sycl::queue &queue, transpose *transa,
26202624
transpose *transb, std::int64_t *m, std::int64_t *n,
2621-
std::int64_t *k, half *alpha, const half **a,
2622-
std::int64_t *lda, const half **b, std::int64_t *ldb,
2623-
half *beta, half **c, std::int64_t *ldc,
2625+
std::int64_t *k, sycl::half *alpha, const sycl::half **a,
2626+
std::int64_t *lda, const sycl::half **b, std::int64_t *ldb,
2627+
sycl::half *beta, sycl::half **c, std::int64_t *ldc,
26242628
std::int64_t group_count, std::int64_t *group_size,
26252629
const std::vector<cl::sycl::event> &dependencies = {}) {
26262630
gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
@@ -2702,11 +2706,12 @@ static inline cl::sycl::event gemm_batch(
27022706
}
27032707

27042708
static inline cl::sycl::event gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb,
2705-
std::int64_t m, std::int64_t n, std::int64_t k, half alpha,
2706-
const half *a, std::int64_t lda, std::int64_t stride_a,
2707-
const half *b, std::int64_t ldb, std::int64_t stride_b,
2708-
half beta, half *c, std::int64_t ldc,
2709-
std::int64_t stride_c, std::int64_t batch_size,
2709+
std::int64_t m, std::int64_t n, std::int64_t k,
2710+
sycl::half alpha, const sycl::half *a, std::int64_t lda,
2711+
std::int64_t stride_a, const sycl::half *b,
2712+
std::int64_t ldb, std::int64_t stride_b, sycl::half beta,
2713+
sycl::half *c, std::int64_t ldc, std::int64_t stride_c,
2714+
std::int64_t batch_size,
27102715
const std::vector<cl::sycl::event> &dependencies = {}) {
27112716
gemm_batch_precondition(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
27122717
stride_b, beta, c, ldc, stride_c, batch_size, dependencies);

include/oneapi/mkl/blas/detail/blas_ct_backends.hxx

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,12 @@ static inline void gemm_batch(backend_selector<backend::BACKEND> selector, trans
457457

458458
static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
459459
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
460-
half alpha, cl::sycl::buffer<half, 1> &a, std::int64_t lda,
461-
std::int64_t stride_a, cl::sycl::buffer<half, 1> &b, std::int64_t ldb,
462-
std::int64_t stride_b, half beta, cl::sycl::buffer<half, 1> &c,
463-
std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size);
460+
sycl::half alpha, cl::sycl::buffer<sycl::half, 1> &a,
461+
std::int64_t lda, std::int64_t stride_a,
462+
cl::sycl::buffer<sycl::half, 1> &b, std::int64_t ldb,
463+
std::int64_t stride_b, sycl::half beta,
464+
cl::sycl::buffer<sycl::half, 1> &c, std::int64_t ldc,
465+
std::int64_t stride_c, std::int64_t batch_size);
464466

465467
static inline void spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
466468
std::int64_t n, float alpha, cl::sycl::buffer<float, 1> &a,
@@ -576,14 +578,14 @@ static inline void gemm(backend_selector<backend::BACKEND> selector, transpose t
576578

577579
static inline void gemm(backend_selector<backend::BACKEND> selector, transpose transa,
578580
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
579-
half alpha, cl::sycl::buffer<half, 1> &a, std::int64_t lda,
580-
cl::sycl::buffer<half, 1> &b, std::int64_t ldb, half beta,
581-
cl::sycl::buffer<half, 1> &c, std::int64_t ldc);
581+
sycl::half alpha, cl::sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
582+
cl::sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, sycl::half beta,
583+
cl::sycl::buffer<sycl::half, 1> &c, std::int64_t ldc);
582584

583585
static inline void gemm(backend_selector<backend::BACKEND> selector, transpose transa,
584586
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
585-
float alpha, cl::sycl::buffer<half, 1> &a, std::int64_t lda,
586-
cl::sycl::buffer<half, 1> &b, std::int64_t ldb, float beta,
587+
float alpha, cl::sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
588+
cl::sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, float beta,
587589
cl::sycl::buffer<float, 1> &c, std::int64_t ldc);
588590

589591
static inline void gemm(backend_selector<backend::BACKEND> selector, transpose transa,
@@ -1700,9 +1702,10 @@ static inline cl::sycl::event gemm_batch(
17001702

17011703
static inline cl::sycl::event gemm_batch(backend_selector<backend::BACKEND> selector,
17021704
transpose *transa, transpose *transb, std::int64_t *m,
1703-
std::int64_t *n, std::int64_t *k, half *alpha,
1704-
const half **a, std::int64_t *lda, const half **b,
1705-
std::int64_t *ldb, half *beta, half **c, std::int64_t *ldc,
1705+
std::int64_t *n, std::int64_t *k, sycl::half *alpha,
1706+
const sycl::half **a, std::int64_t *lda,
1707+
const sycl::half **b, std::int64_t *ldb, sycl::half *beta,
1708+
sycl::half **c, std::int64_t *ldc,
17061709
std::int64_t group_count, std::int64_t *group_size,
17071710
const std::vector<cl::sycl::event> &dependencies = {});
17081711

@@ -1740,14 +1743,12 @@ static inline cl::sycl::event gemm_batch(
17401743
std::int64_t stride_c, std::int64_t batch_size,
17411744
const std::vector<cl::sycl::event> &dependencies = {});
17421745

1743-
static inline cl::sycl::event gemm_batch(backend_selector<backend::BACKEND> selector,
1744-
transpose transa, transpose transb, std::int64_t m,
1745-
std::int64_t n, std::int64_t k, half alpha, const half *a,
1746-
std::int64_t lda, std::int64_t stride_a, const half *b,
1747-
std::int64_t ldb, std::int64_t stride_b, half beta,
1748-
half *c, std::int64_t ldc, std::int64_t stride_c,
1749-
std::int64_t batch_size,
1750-
const std::vector<cl::sycl::event> &dependencies = {});
1746+
static inline cl::sycl::event gemm_batch(
1747+
backend_selector<backend::BACKEND> selector, transpose transa, transpose transb, std::int64_t m,
1748+
std::int64_t n, std::int64_t k, sycl::half alpha, const sycl::half *a, std::int64_t lda,
1749+
std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b,
1750+
sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c,
1751+
std::int64_t batch_size, const std::vector<cl::sycl::event> &dependencies = {});
17511752

17521753
static inline cl::sycl::event spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
17531754
std::int64_t n, float alpha, const float *a, const float *x,
@@ -1837,14 +1838,16 @@ static inline cl::sycl::event gemm(backend_selector<backend::BACKEND> selector,
18371838

18381839
static inline cl::sycl::event gemm(backend_selector<backend::BACKEND> selector, transpose transa,
18391840
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
1840-
half alpha, const half *a, std::int64_t lda, const half *b,
1841-
std::int64_t ldb, half beta, half *c, std::int64_t ldc,
1841+
sycl::half alpha, const sycl::half *a, std::int64_t lda,
1842+
const sycl::half *b, std::int64_t ldb, sycl::half beta,
1843+
sycl::half *c, std::int64_t ldc,
18421844
const std::vector<cl::sycl::event> &dependencies = {});
18431845

18441846
static inline cl::sycl::event gemm(backend_selector<backend::BACKEND> selector, transpose transa,
18451847
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
1846-
float alpha, const half *a, std::int64_t lda, const half *b,
1847-
std::int64_t ldb, float beta, float *c, std::int64_t ldc,
1848+
float alpha, const sycl::half *a, std::int64_t lda,
1849+
const sycl::half *b, std::int64_t ldb, float beta, float *c,
1850+
std::int64_t ldc,
18481851
const std::vector<cl::sycl::event> &dependencies = {});
18491852

18501853
static inline cl::sycl::event gemm(backend_selector<backend::BACKEND> selector, transpose transa,

0 commit comments

Comments
 (0)