@@ -57,12 +57,12 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)` it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, which is almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, which is almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates~ (👇Benchmark)
- | Algorithm| (B,H,N,D) | RTX 3080 | L20 | RTX 4090 |
+ | Algorithm| (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
| :---:| :---:| :---:| :---:| :---:|
| FlashAttention-2| (1,8,8192,64)| 37 TFLOPS| 100 TFLOPS| 145 TFLOPS|
| split-q+share-qkv+stage2| (1,8,8192,64)| **55 TFLOPS**| 99 TFLOPS| **218 TFLOPS**|
| FlashAttention-2| (1,48,8192,64)| 37 TFLOPS| 109 TFLOPS| 163 TFLOPS|
- | split-q+share-qkv+stage2| (1,48,8192,64)| 35 TFLOPS| 107 TFLOPS| **220 TFLOPS**|
+ | split-q+share-qkv+stage2| (1,48,8192,64)| **48 TFLOPS**| 107 TFLOPS| **220 TFLOPS**|
| SDPA(EFFICIENT ATTENTION)| (1,48,8192,512)| 16 TFLOPS| 58 TFLOPS| 85 TFLOPS|
| split-q+tiling-qk+swizzle-qk+stage2| (1,48,8192,512)| **23 TFLOPS**| **81 TFLOPS**| **127 TFLOPS**|
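
To make "pure MMA PTX instructions" concrete, below is a minimal, self-contained sketch (illustrative only, not code from this repo; the single-tile shapes and buffer names are assumptions) of the `mma.sync.aligned.m16n8k16` building block that the `flash_attn_mma_*` kernels compose into full attention. It packs the per-thread fragments by hand following the PTX ISA layouts for row-major A / col-major B with f16 accumulators, and requires `sm_80+`:

```cuda
// hmma_m16n8k16.cu -- build with: nvcc -arch=sm_80 hmma_m16n8k16.cu -o hmma
// Illustrative sketch only: one warp computes D = A @ B for a single
// m16n8k16 tile (f16 inputs, f16 accumulators) via raw HMMA PTX.
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

__global__ void hmma_16x8x16(const half *A, const half *B, half *D) {
  const int lane = threadIdx.x; // one warp: lanes 0..31
  const int g = lane >> 2;      // group id          -> target row 0..7
  const int t = lane & 3;       // thread in group   -> column pair 0..3

  // Pack per-thread fragments exactly as the PTX ISA prescribes for
  // mma.m16n8k16 (A: 16x16 row-major; B: 16x8, col-major fragment).
  uint32_t a[4], b[2], c[2] = {0u, 0u}, d[2];
  a[0] = *reinterpret_cast<const uint32_t *>(&A[(g + 0) * 16 + t * 2 + 0]);
  a[1] = *reinterpret_cast<const uint32_t *>(&A[(g + 8) * 16 + t * 2 + 0]);
  a[2] = *reinterpret_cast<const uint32_t *>(&A[(g + 0) * 16 + t * 2 + 8]);
  a[3] = *reinterpret_cast<const uint32_t *>(&A[(g + 8) * 16 + t * 2 + 8]);
  // B is stored 16(K) x 8(N) row-major here, so the two halves of each
  // col-major fragment register are 8 elements apart -- pack them by hand.
  half2 b0 = __halves2half2(B[(t * 2 + 0) * 8 + g], B[(t * 2 + 1) * 8 + g]);
  half2 b1 = __halves2half2(B[(t * 2 + 8) * 8 + g], B[(t * 2 + 9) * 8 + g]);
  b[0] = *reinterpret_cast<uint32_t *>(&b0);
  b[1] = *reinterpret_cast<uint32_t *>(&b1);

  // One Tensor Core instruction: {d0,d1} = A(16x16) @ B(16x8) + {c0,c1}.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d[0]), "=r"(d[1])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]), "r"(c[0]), "r"(c[1]));

  // Scatter the 16x8 D fragment back: d[0] -> row g, d[1] -> row g+8.
  *reinterpret_cast<uint32_t *>(&D[(g + 0) * 8 + t * 2]) = d[0];
  *reinterpret_cast<uint32_t *>(&D[(g + 8) * 8 + t * 2]) = d[1];
}

int main() {
  half *A, *B, *D;
  cudaMallocManaged(&A, 16 * 16 * sizeof(half));
  cudaMallocManaged(&B, 16 * 8 * sizeof(half));
  cudaMallocManaged(&D, 16 * 8 * sizeof(half));
  for (int i = 0; i < 16 * 16; ++i) A[i] = __float2half(1.0f);
  for (int i = 0; i < 16 * 8; ++i)  B[i] = __float2half(1.0f);
  hmma_16x8x16<<<1, 32>>>(A, B, D);
  cudaDeviceSynchronize();
  printf("D[0][0] = %.1f (expect 16.0 for all-ones inputs)\n",
         __half2float(D[0]));
  return 0;
}
```

With all-ones inputs this should print `16.0`, i.e. the K=16 reduction performed by a single Tensor Core instruction.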
@@ -316,17 +316,17 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| :---| :---| :---| :---| :---|
| ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/basic/flash_attn_mma_split_kv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/basic/flash_attn_mma_split_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_kv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_kv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_kv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...shared_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...tiling_qk_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
- | ✔️ [flash_attn_mma...tiling_qk_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/basic/flash_attn_mma_tiling_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_kv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_kv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_kv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_kv_swizzle_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...shared_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...tiling_qk_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...tiling_qk_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu)| f16| f16| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages_split_q{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_split_q_acc_f32.cu)| f16| f32| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/basic/flash_attn_mma_share_kv_acc_f32.cu)| f16| f32| [link](./kernels/flash-attn)| ⭐️⭐️⭐️⭐️⭐️|
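
The `*_swizzle{q,qk,qkv}` kernels above differ from the basic versions mainly in how Q/K/V tiles are laid out in shared memory. As a hedged sketch of the general XOR-swizzle idea (not this repo's exact layout; the 64-half row width and 3-bit XOR are assumed parameters), a permuted index like the one below spreads the 8 rows touched by a single `ldmatrix` across distinct bank groups, avoiding shared-memory bank conflicts without padding:

```cuda
#include <cuda_fp16.h>

// Hedged sketch of an XOR-based shared-memory swizzle for f16 tiles.
// Assumed layout: 64 halfs (128 bytes) per logical row; data moves in
// 16-byte (8-half) chunks, the granularity of ldmatrix accesses.
__device__ __forceinline__ int swizzle_offset(int row, int col) {
  int chunk = col >> 3;         // which 8-half chunk inside the row
  int swz = chunk ^ (row & 7);  // XOR in the low row bits permutes chunks,
                                // so rows 0..7 land in 8 distinct bank groups
  return row * 64 + (swz << 3) + (col & 7);
}

// Both the gmem->smem store and the smem->reg ldmatrix read must go through
// the same mapping, e.g.: smem[swizzle_offset(r, c)] = gmem_tile[r * 64 + c];
```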