@@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr
382382 stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
383383}
384384
385+ static inline void gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
386+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
387+ sycl::buffer<sycl::half, 1 > &a, std::int64_t lda,
388+ std::int64_t stride_a, sycl::buffer<sycl::half, 1 > &b,
389+ std::int64_t ldb, std::int64_t stride_b, float beta,
390+ sycl::buffer<float , 1 > &c, std::int64_t ldc, std::int64_t stride_c,
391+ std::int64_t batch_size) {
392+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda,
393+ stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
394+ }
395+
396+ static inline void gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
397+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
398+ sycl::buffer<std::int8_t , 1 > &a, std::int64_t lda,
399+ std::int64_t stride_a, sycl::buffer<std::int8_t , 1 > &b,
400+ std::int64_t ldb, std::int64_t stride_b, float beta,
401+ sycl::buffer<float , 1 > &c, std::int64_t ldc, std::int64_t stride_c,
402+ std::int64_t batch_size) {
403+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda,
404+ stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
405+ }
406+
407+ static inline void gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
408+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
409+ sycl::buffer<std::int8_t , 1 > &a, std::int64_t lda,
410+ std::int64_t stride_a, sycl::buffer<std::int8_t , 1 > &b,
411+ std::int64_t ldb, std::int64_t stride_b, float beta,
412+ sycl::buffer<std::int32_t , 1 > &c, std::int64_t ldc,
413+ std::int64_t stride_c, std::int64_t batch_size) {
414+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda,
415+ stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
416+ }
417+
385418static inline void gemm_bias (sycl::queue &queue, transpose transa, transpose transb,
386419 offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k,
387420 float alpha, sycl::buffer<int8_t , 1 > &a, std::int64_t lda,
@@ -2246,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa,
22462279 return done;
22472280}
22482281
2282+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose *transa, transpose *transb,
2283+ std::int64_t *m, std::int64_t *n, std::int64_t *k,
2284+ float *alpha, const sycl::half **a, std::int64_t *lda,
2285+ const sycl::half **b, std::int64_t *ldb, float *beta,
2286+ float **c, std::int64_t *ldc, std::int64_t group_count,
2287+ std::int64_t *group_size,
2288+ const std::vector<sycl::event> &dependencies = {}) {
2289+ auto done =
2290+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2291+ ldb, beta, c, ldc, group_count, group_size, dependencies);
2292+ return done;
2293+ }
2294+
2295+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose *transa, transpose *transb,
2296+ std::int64_t *m, std::int64_t *n, std::int64_t *k,
2297+ float *alpha, const std::int8_t **a, std::int64_t *lda,
2298+ const std::int8_t **b, std::int64_t *ldb, float *beta,
2299+ float **c, std::int64_t *ldc, std::int64_t group_count,
2300+ std::int64_t *group_size,
2301+ const std::vector<sycl::event> &dependencies = {}) {
2302+ auto done =
2303+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2304+ ldb, beta, c, ldc, group_count, group_size, dependencies);
2305+ return done;
2306+ }
2307+
2308+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose *transa, transpose *transb,
2309+ std::int64_t *m, std::int64_t *n, std::int64_t *k,
2310+ float *alpha, const std::int8_t **a, std::int64_t *lda,
2311+ const std::int8_t **b, std::int64_t *ldb, float *beta,
2312+ std::int32_t **c, std::int64_t *ldc, std::int64_t group_count,
2313+ std::int64_t *group_size,
2314+ const std::vector<sycl::event> &dependencies = {}) {
2315+ auto done =
2316+ detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
2317+ ldb, beta, c, ldc, group_count, group_size, dependencies);
2318+ return done;
2319+ }
2320+
22492321static inline sycl::event gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
22502322 std::int64_t m, std::int64_t n, std::int64_t k,
22512323 float alpha, const float *a, std::int64_t lda,
@@ -2312,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans
23122384 return done;
23132385}
23142386
2387+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
2388+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2389+ const sycl::half *a, std::int64_t lda, std::int64_t stride_a,
2390+ const sycl::half *b, std::int64_t ldb, std::int64_t stride_b,
2391+ float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
2392+ std::int64_t batch_size,
2393+ const std::vector<sycl::event> &dependencies = {}) {
2394+ auto done = detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a,
2395+ lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2396+ batch_size, dependencies);
2397+ return done;
2398+ }
2399+
2400+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
2401+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2402+ const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
2403+ const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
2404+ float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
2405+ std::int64_t batch_size,
2406+ const std::vector<sycl::event> &dependencies = {}) {
2407+ auto done = detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a,
2408+ lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2409+ batch_size, dependencies);
2410+ return done;
2411+ }
2412+
2413+ static inline sycl::event gemm_batch (sycl::queue &queue, transpose transa, transpose transb,
2414+ std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
2415+ const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
2416+ const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
2417+ float beta, std::int32_t *c, std::int64_t ldc,
2418+ std::int64_t stride_c, std::int64_t batch_size,
2419+ const std::vector<sycl::event> &dependencies = {}) {
2420+ auto done = detail::gemm_batch (get_device_id (queue), queue, transa, transb, m, n, k, alpha, a,
2421+ lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2422+ batch_size, dependencies);
2423+ return done;
2424+ }
2425+
23152426static inline sycl::event gemmt (sycl::queue &queue, uplo upper_lower, transpose transa,
23162427 transpose transb, std::int64_t n, std::int64_t k, float alpha,
23172428 const float *a, std::int64_t lda, const float *b,
0 commit comments