Commit 1492631

[HGEMM] Add some note to collective store (#103)
* Update hgemm_mma_stage.cu
* Update README.md
* Update README.md
1 parent 2f5740b commit 1492631

File tree

3 files changed (+3 −2 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@
  | ✔️ [hgemv_k16_f16](./hgemv/hgemv.cu)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
  | ✔️ [flash_attn_1_fwd_f32](./flash-attn/flash_attn.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
  | ✔️ [flash_attn_2_fwd_f16_m16n8k16*](./flash-attn/flash_attn_mma.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
- | ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️|
+ | ✔️ [nms_kernel](./nms/nms.cu)|f32|/|[link](./nms)|⭐️⭐️|
  | ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️|

  👉TIPS: * means using **Tensor Cores(MMA/WMMA)**, otherwise, using CUDA Cores by default.

hgemm/README.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@

  - NVIDIA L20

- Currently the best implementation: on an L20 (theoretical Tensor Cores FP16 peak of 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, exceeding cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Manual smem swizzle has not been implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a later attempt will implement smem swizzle via MMA PTX and a col-major layout. [See performance data](#NV-L20)
+ Currently the best implementation: on an L20 (theoretical Tensor Cores FP16 peak of 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, exceeding cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Manual smem swizzle/permute has not been implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a later attempt will implement smem swizzle/permute via MMA PTX and a col-major layout. [See performance data](#NV-L20)

  - NVIDIA GeForce RTX 3080 Laptop
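The padding trade-off described in the note above can be illustrated with a minimal sketch. The tile sizes `BM`, `BK`, the pad amount `A_PAD`, and the loading scheme below are illustrative assumptions, not the repository's actual kernel parameters:

```cuda
#include <cuda_fp16.h>

// Illustrative sizes only (assumed, not the kernel's real parameters).
#define BM 128      // rows of the shared-memory A tile
#define BK 16       // halfs per row (32 bytes)
#define A_PAD 8     // extra halfs per row: 16 bytes = 4 banks of stagger

__global__ void padded_smem_tile_demo(const half *__restrict__ A, int lda) {
  // Without the pad, every 32-byte row starts at the same bank offset
  // (covering only 8 of the 32 banks), so lanes reading the same column of
  // different rows collide. The pad staggers the starting bank of each row,
  // at the cost of BM * A_PAD extra halfs of shared memory per tile --
  // the waste and occupancy impact mentioned in the note.
  __shared__ half s_a[BM][BK + A_PAD];

  int row = threadIdx.x / (BK / 8);   // each thread copies 8 halfs (128 bits)
  int col = (threadIdx.x % (BK / 8)) * 8;
  if (row < BM) {
    // lda is assumed to be a multiple of 8 so the global address stays aligned.
    *reinterpret_cast<float4 *>(&s_a[row][col]) =
        *reinterpret_cast<const float4 *>(&A[row * lda + col]);
  }
  __syncthreads();
  // ... WMMA/MMA fragments would be loaded from s_a here ...
}
```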

hgemm/hgemm_mma_stage.cu

Lines changed: 1 addition & 0 deletions
@@ -1015,6 +1015,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
      }
    }

+   // collective store with reg reuse & warp shuffle
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // reuse RA[2][4][4] reg here, this may boost 0.3~0.5 TFLOPS up.
      // may not put 'if' in N loop, it will crash the 'pragma unroll' hint ?
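For context, the collective-store pattern the new comment documents can be sketched roughly as below. This is a simplified illustration under an assumed fragment layout; the names `rc` and `gmem_row` are hypothetical, and the real kernel instead reuses its `RA[2][4][4]` registers as the staging buffer and computes its own tile offsets into C:

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Simplified sketch of a collective store with warp shuffle (assumed names
// and layout; not the kernel's actual code). For an m16n8k16 C fragment,
// each group of 4 lanes holds one 8-half row segment as 4x 32-bit registers.
// Lanes 1..3 shuffle their register over to lane 0 of the group, which then
// issues a single 128-bit store instead of four separate 32-bit stores.
__device__ void collective_store_row(uint32_t rc, half *gmem_row) {
  const int lane_id = threadIdx.x % 32;

  uint4 packed;  // plays the role of the reused RA registers in the kernel
  packed.x = rc;
  packed.y = __shfl_sync(0xffffffff, rc, lane_id + 1);
  packed.z = __shfl_sync(0xffffffff, rc, lane_id + 2);
  packed.w = __shfl_sync(0xffffffff, rc, lane_id + 3);

  if (lane_id % 4 == 0) {
    // One 128-bit global store per 4 lanes; gmem_row is assumed to point to
    // 8 consecutive halfs of the output row and to be 16-byte aligned.
    *reinterpret_cast<uint4 *>(gmem_row) = packed;
  }
}
```

Reusing already-allocated registers for the staging buffer (as the kernel does with RA) keeps register pressure flat, which is where the cited 0.3~0.5 TFLOPS gain comes from.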
