Skip to content

Commit 3b56750

Browse files
authored
[SGEMM][Async] Add K16 + Copy Async Kernel (#65)
* Update hgemm_async.cu * Update sgemm_async.cu * Update sgemm.cu * Update sgemm.py * Update README.md * Update README.md
1 parent 195158a commit 3b56750

File tree

6 files changed

+1033
-163
lines changed

6 files changed

+1033
-163
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
101101
| ✔️ [sgemm_t_8x8_sliced_k...bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
102102
| ✔️ [sgemm_t_8x8_sliced_k...dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
103-
| ✔️ [sgemm_t_8x8_sliced_k...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
103+
| ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
104+
| ✔️ [sgemm_t_8x8_sliced_k16...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
104105
| ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
105106
| ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
106107
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|

hgemm/hgemm_async.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
2020
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
2121
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
22-
// ca(cache all): support 4, 8, 16 bytes, cg(cache global): only support 16 bytes.
22+
// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
2323
#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
2424
#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
2525

26+
2627
template<const int BM=128, const int BN=128, const int BK=16,
2728
const int TM=8, const int TN=8, const int OFFSET=0>
2829
__global__ void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_kernel(

0 commit comments

Comments
 (0)