Skip to content

Commit b08b41f

Browse files
AidanBeltonSnormallytangent
authored andcommitted
[BLAS] Add new batch_gemm types (#466)
Add support for more batch_gemm types to follow the specification. Some combination using int8 are disabled on some backends due to precision issue.
1 parent 7cdd083 commit b08b41f

32 files changed

+2679
-392
lines changed

include/oneapi/mkl/blas.hxx

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr
382382
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
383383
}
384384

385+
static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
386+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
387+
sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
388+
std::int64_t stride_a, sycl::buffer<sycl::half, 1> &b,
389+
std::int64_t ldb, std::int64_t stride_b, float beta,
390+
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
391+
std::int64_t batch_size) {
392+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
393+
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
394+
}
395+
396+
static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
397+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
398+
sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
399+
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
400+
std::int64_t ldb, std::int64_t stride_b, float beta,
401+
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
402+
std::int64_t batch_size) {
403+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
404+
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
405+
}
406+
407+
static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
408+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
409+
sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
410+
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
411+
std::int64_t ldb, std::int64_t stride_b, float beta,
412+
sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
413+
std::int64_t stride_c, std::int64_t batch_size) {
414+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
415+
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
416+
}
417+
385418
static inline void gemm_bias(sycl::queue &queue, transpose transa, transpose transb,
386419
offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k,
387420
float alpha, sycl::buffer<int8_t, 1> &a, std::int64_t lda,
@@ -2246,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa,
22462279
return done;
22472280
}
22482281

2282+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
2283+
std::int64_t *m, std::int64_t *n, std::int64_t *k,
2284+
float *alpha, const sycl::half **a, std::int64_t *lda,
2285+
const sycl::half **b, std::int64_t *ldb, float *beta,
2286+
float **c, std::int64_t *ldc, std::int64_t group_count,
2287+
std::int64_t *group_size,
2288+
const std::vector<sycl::event> &dependencies = {}) {
2289+
auto done =
2290+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2291+
ldb, beta, c, ldc, group_count, group_size, dependencies);
2292+
return done;
2293+
}
2294+
2295+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
2296+
std::int64_t *m, std::int64_t *n, std::int64_t *k,
2297+
float *alpha, const std::int8_t **a, std::int64_t *lda,
2298+
const std::int8_t **b, std::int64_t *ldb, float *beta,
2299+
float **c, std::int64_t *ldc, std::int64_t group_count,
2300+
std::int64_t *group_size,
2301+
const std::vector<sycl::event> &dependencies = {}) {
2302+
auto done =
2303+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2304+
ldb, beta, c, ldc, group_count, group_size, dependencies);
2305+
return done;
2306+
}
2307+
2308+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
2309+
std::int64_t *m, std::int64_t *n, std::int64_t *k,
2310+
float *alpha, const std::int8_t **a, std::int64_t *lda,
2311+
const std::int8_t **b, std::int64_t *ldb, float *beta,
2312+
std::int32_t **c, std::int64_t *ldc, std::int64_t group_count,
2313+
std::int64_t *group_size,
2314+
const std::vector<sycl::event> &dependencies = {}) {
2315+
auto done =
2316+
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2317+
ldb, beta, c, ldc, group_count, group_size, dependencies);
2318+
return done;
2319+
}
2320+
22492321
static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
22502322
std::int64_t m, std::int64_t n, std::int64_t k,
22512323
float alpha, const float *a, std::int64_t lda,
@@ -2312,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans
23122384
return done;
23132385
}
23142386

2387+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
2388+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2389+
const sycl::half *a, std::int64_t lda, std::int64_t stride_a,
2390+
const sycl::half *b, std::int64_t ldb, std::int64_t stride_b,
2391+
float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
2392+
std::int64_t batch_size,
2393+
const std::vector<sycl::event> &dependencies = {}) {
2394+
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
2395+
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2396+
batch_size, dependencies);
2397+
return done;
2398+
}
2399+
2400+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
2401+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2402+
const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
2403+
const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
2404+
float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
2405+
std::int64_t batch_size,
2406+
const std::vector<sycl::event> &dependencies = {}) {
2407+
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
2408+
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2409+
batch_size, dependencies);
2410+
return done;
2411+
}
2412+
2413+
static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
2414+
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2415+
const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
2416+
const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
2417+
float beta, std::int32_t *c, std::int64_t ldc,
2418+
std::int64_t stride_c, std::int64_t batch_size,
2419+
const std::vector<sycl::event> &dependencies = {}) {
2420+
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
2421+
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2422+
batch_size, dependencies);
2423+
return done;
2424+
}
2425+
23152426
static inline sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa,
23162427
transpose transb, std::int64_t n, std::int64_t k, float alpha,
23172428
const float *a, std::int64_t lda, const float *b,

include/oneapi/mkl/blas/detail/blas_ct_backends.hxx

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,30 @@ static inline void gemm_batch(backend_selector<backend::BACKEND> selector, trans
464464
sycl::buffer<sycl::half, 1> &c, std::int64_t ldc,
465465
std::int64_t stride_c, std::int64_t batch_size);
466466

467+
static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
468+
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
469+
float alpha, sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
470+
std::int64_t stride_a, sycl::buffer<sycl::half, 1> &b,
471+
std::int64_t ldb, std::int64_t stride_b, float beta,
472+
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
473+
std::int64_t batch_size);
474+
475+
static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
476+
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
477+
float alpha, sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
478+
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
479+
std::int64_t ldb, std::int64_t stride_b, float beta,
480+
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
481+
std::int64_t batch_size);
482+
483+
static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
484+
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
485+
float alpha, sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
486+
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
487+
std::int64_t ldb, std::int64_t stride_b, float beta,
488+
sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
489+
std::int64_t stride_c, std::int64_t batch_size);
490+
467491
static inline void spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
468492
std::int64_t n, float alpha, sycl::buffer<float, 1> &a,
469493
sycl::buffer<float, 1> &x, std::int64_t incx, float beta,
@@ -1870,6 +1894,30 @@ static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector
18701894
std::int64_t group_count, std::int64_t *group_size,
18711895
const std::vector<sycl::event> &dependencies = {});
18721896

1897+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
1898+
transpose *transb, std::int64_t *m, std::int64_t *n,
1899+
std::int64_t *k, float *alpha, const sycl::half **a,
1900+
std::int64_t *lda, const sycl::half **b, std::int64_t *ldb,
1901+
float *beta, float **c, std::int64_t *ldc,
1902+
std::int64_t group_count, std::int64_t *group_size,
1903+
const std::vector<sycl::event> &dependencies = {});
1904+
1905+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
1906+
transpose *transb, std::int64_t *m, std::int64_t *n,
1907+
std::int64_t *k, float *alpha, const std::int8_t **a,
1908+
std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb,
1909+
float *beta, float **c, std::int64_t *ldc,
1910+
std::int64_t group_count, std::int64_t *group_size,
1911+
const std::vector<sycl::event> &dependencies = {});
1912+
1913+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
1914+
transpose *transb, std::int64_t *m, std::int64_t *n,
1915+
std::int64_t *k, float *alpha, const std::int8_t **a,
1916+
std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb,
1917+
float *beta, std::int32_t **c, std::int64_t *ldc,
1918+
std::int64_t group_count, std::int64_t *group_size,
1919+
const std::vector<sycl::event> &dependencies = {});
1920+
18731921
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector,
18741922
transpose transa, transpose transb, std::int64_t m,
18751923
std::int64_t n, std::int64_t k, float alpha,
@@ -1911,6 +1959,33 @@ static inline sycl::event gemm_batch(
19111959
sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c,
19121960
std::int64_t batch_size, const std::vector<sycl::event> &dependencies = {});
19131961

1962+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
1963+
transpose transb, std::int64_t m, std::int64_t n,
1964+
std::int64_t k, float alpha, const sycl::half *a,
1965+
std::int64_t lda, std::int64_t stride_a, const sycl::half *b,
1966+
std::int64_t ldb, std::int64_t stride_b, float beta, float *c,
1967+
std::int64_t ldc, std::int64_t stride_c,
1968+
std::int64_t batch_size,
1969+
const std::vector<sycl::event> &dependencies = {});
1970+
1971+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
1972+
transpose transb, std::int64_t m, std::int64_t n,
1973+
std::int64_t k, float alpha, const std::int8_t *a,
1974+
std::int64_t lda, std::int64_t stride_a, const std::int8_t *b,
1975+
std::int64_t ldb, std::int64_t stride_b, float beta, float *c,
1976+
std::int64_t ldc, std::int64_t stride_c,
1977+
std::int64_t batch_size,
1978+
const std::vector<sycl::event> &dependencies = {});
1979+
1980+
static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
1981+
transpose transb, std::int64_t m, std::int64_t n,
1982+
std::int64_t k, float alpha, const std::int8_t *a,
1983+
std::int64_t lda, std::int64_t stride_a, const std::int8_t *b,
1984+
std::int64_t ldb, std::int64_t stride_b, float beta,
1985+
std::int32_t *c, std::int64_t ldc, std::int64_t stride_c,
1986+
std::int64_t batch_size,
1987+
const std::vector<sycl::event> &dependencies = {});
1988+
19141989
static inline sycl::event spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
19151990
std::int64_t n, float alpha, const float *a, const float *x,
19161991
std::int64_t incx, float beta, float *y, std::int64_t incy,

0 commit comments

Comments
 (0)