
Commit 82f1d04

[FA2] tiling-qkv F32/F16 + swizzle q/qk/qkv🎉 (#213)
* Update flash_attn_mma_tiling_qkv.cu
* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu
* Update flash_attn_mma_tiling_qkv_swizzle_q.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma_tiling_qkv_swizzle_qkv.cu
* Update flash_attn_mma_tiling_qkv_swizzle_q.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu
* Update flash_attn_mma_tiling_qkv_swizzle_q.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qkv.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma_share_kv.cu
* Update flash_attn_mma_share_kv_F32F16F16F32.cu
* Update flash_attn_mma_split_kv.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_tiling_qk.cu
* Update flash_attn_mma_tiling_qk_F32F16F16F32.cu
* Update flash_attn_mma_tiling_qkv.cu
* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu
* Create flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu
* Create flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu
* Create flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu
* Update README.md
* Update README.md
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update flash_attn_mma_tiling_qkv_swizzle_q.cu
* Update flash_attn_mma_tiling_qkv_swizzle_q.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu
* Update flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu
* Update flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu
* Update flash_attn_mma.py
* Update flash_attn.cc
* Update README.md
* Update README.md
1 parent c687988 · commit 82f1d04

18 files changed: +6509 −156 lines

README.md

Lines changed: 10 additions & 7 deletions
@@ -55,16 +55,16 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QKV Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|

- Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)` it can run faster than FA2/SDPA on some Devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
+ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)`, it can run faster than FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On an NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **92 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
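The "MMA Acc F16/F32" note above is the accumulator-precision choice that the new `*_F32F16F16F32.cu` variants expose: the same m16n8k16 HMMA tile can accumulate either in FP16 or in FP32. Below is a minimal sketch of the two PTX forms (SM80+); the wrapper names are illustrative, not the kernels' own.

```cuda
#include <cstdint>

// m16n8k16 HMMA, FP16 inputs, FP16 accumulators ("MMA Acc F16"):
// A = 4 x .b32 regs (8 halves), B = 2 x .b32, C/D = 2 x .b32 (4 halves each).
__device__ __forceinline__ void hmma_m16n8k16_f16acc(
    uint32_t* D, const uint32_t* A, const uint32_t* B, const uint32_t* C) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
}

// Same tile shape with FP32 accumulators ("MMA Acc F32", the *_F32F16F16F32
// path): C/D become 4 x .f32 regs, doubling accumulator register usage but
// keeping the Q*K^T and P*V accumulation in full FP32.
__device__ __forceinline__ void hmma_m16n8k16_f32acc(
    float* D, const uint32_t* A, const uint32_t* B, const float* C) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
```

The FP32 form trades extra register pressure for precision, which is presumably why it is offered as a separate set of kernels rather than as the default.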

|Algorithm| (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
- |split-q+share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
+ |share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
- |split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
+ |share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
- |split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
+ |tiling-qkv+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**92 TFLOPS**|**157 TFLOPS**|
|Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

Both the `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA warps, is slower than the `Split Q` method, which splits only Q across MMA warps while keeping the full K/V accessible to all of them.
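To make the `Split Q` layout above concrete, here is a minimal, non-MMA sketch of just the partitioning: one block owns a small tile of query rows for one (batch, head), each warp owns exactly one of those rows, and every warp walks the same K/V tiles staged in shared memory with an online softmax. It assumes FP32 tensors in row-major `[B*H, N, D]` layout, `D = 64`, and `N` divisible by the tile sizes; all names are illustrative and this is not the repository's MMA implementation.

```cuda
#include <cuda_runtime.h>

#define WARPS_PER_BLOCK 4   // Br: query rows per block (one per warp)
#define BK 64               // Bc: K/V rows staged per SMEM tile
#define HEAD_DIM 64         // D (fixed to 64 in this sketch)

// Launch sketch: <<<dim3(N / WARPS_PER_BLOCK, B * H), WARPS_PER_BLOCK * 32>>>
__global__ void split_q_attn_sketch(const float* Q, const float* K,
                                    const float* V, float* O, int seqlen) {
  __shared__ float sK[BK][HEAD_DIM];   // K/V tiles are shared by ALL warps
  __shared__ float sV[BK][HEAD_DIM];

  const int warp = threadIdx.x >> 5, lane = threadIdx.x & 31;
  const int q_row = blockIdx.x * WARPS_PER_BLOCK + warp;       // Split Q: one row per warp
  const size_t base = (size_t)blockIdx.y * seqlen * HEAD_DIM;  // (batch * head) offset
  const float scale = rsqrtf((float)HEAD_DIM);

  float q_reg[HEAD_DIM / 32];          // "Prefetch Q" into registers (s2r)
  for (int i = 0; i < HEAD_DIM / 32; ++i)
    q_reg[i] = Q[base + (size_t)q_row * HEAD_DIM + lane + 32 * i];

  float m = -1e30f, l = 0.f;           // online-softmax running max / sum
  float acc[HEAD_DIM / 32] = {0.f};    // this row's output accumulator

  for (int kv0 = 0; kv0 < seqlen; kv0 += BK) {
    // g2s: all threads cooperatively stage the next K/V tile
    // (a real kernel would overlap this with cp.async multi-stage pipelining).
    for (int i = threadIdx.x; i < BK * HEAD_DIM; i += blockDim.x) {
      int r = i / HEAD_DIM, c = i % HEAD_DIM;
      sK[r][c] = K[base + (size_t)(kv0 + r) * HEAD_DIM + c];
      sV[r][c] = V[base + (size_t)(kv0 + r) * HEAD_DIM + c];
    }
    __syncthreads();

    for (int j = 0; j < BK; ++j) {     // this warp's query row vs. key j
      float s = 0.f;
      for (int i = 0; i < HEAD_DIM / 32; ++i) s += q_reg[i] * sK[j][lane + 32 * i];
      for (int o = 16; o > 0; o >>= 1) s += __shfl_xor_sync(0xffffffffu, s, o);
      s *= scale;                      // every lane now holds the full dot product
      float m_new = fmaxf(m, s);       // online softmax: rescale then accumulate
      float corr = __expf(m - m_new), p = __expf(s - m_new);
      l = l * corr + p;
      for (int i = 0; i < HEAD_DIM / 32; ++i)
        acc[i] = acc[i] * corr + p * sV[j][lane + 32 * i];
      m = m_new;
    }
    __syncthreads();                   // keep the tile intact until all warps finish it
  }
  for (int i = 0; i < HEAD_DIM / 32; ++i)
    O[base + (size_t)q_row * HEAD_DIM + lane + 32 * i] = acc[i] / l;
}
```

A `Split KV` layout would instead hand each warp its own slice of the K/V tiles, so every warp ends up with partial softmax statistics (m, l) and partial outputs that must be reduced across warps afterwards, which is roughly where its extra overhead comes from.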
@@ -369,9 +369,12 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [flash_attn_mma...tiling_qk_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
- | ? [flash_attn_mma...tiling_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
- | ? [flash_attn_mma...tiling_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
- | ? [flash_attn_mma...tiling_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...tiling_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...tiling_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn_mma...tiling_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn...tiling_qkv_swizzle{q}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn...tiling_qkv_swizzle{qk}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+ | ✔️ [flash_attn...tiling_qkv_swizzle{qkv}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|

**rr**: means reduced register usage (for `d>128`); **f32**: means the MMA accumulates with FP32 dtype, otherwise FP16 (the softmax Acc dtype is always FP32 for high precision); **swizzle**: currently, only smem swizzle for MMA is supported.
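For the **swizzle** kernels, the idea is a shared-memory layout permutation that keeps `ldmatrix`/MMA tile loads free of bank conflicts: the 16-byte chunks of each SMEM row are XOR-permuted by low row bits, so reading a "column" of chunks across rows no longer hits the same banks. Below is a minimal sketch of one common pattern, assuming a row-major half-precision tile with 64 halves (128 bytes) per row; the exact masks depend on the tile shape, and these helpers are illustrative rather than the kernels' actual indexing.

```cuda
// Swizzle at the granularity ldmatrix/cp.async work with: 16-byte chunks
// (8 halves). With 64 halves (128 bytes = all 32 banks) per row, chunk k of
// every row would otherwise map to the same 4 banks; XOR-ing the chunk index
// with the 3 low row bits spreads it across rows 0..7.
__device__ __forceinline__ int swizzle_chunk(int row, int chunk /* 0..7 */) {
  return chunk ^ (row & 0x7);
}

// Translate a logical (row, col) into a swizzled offset for a tile stored as
// `half s_tile[ROWS][64]`, with `col` measured in halves.
__device__ __forceinline__ int swizzle_offset(int row, int col) {
  int chunk = col >> 3;   // which 8-half chunk inside the row
  int inner = col & 7;    // position inside the chunk (left unpermuted)
  return row * 64 + (swizzle_chunk(row, chunk) << 3) + inner;
}
```

Both the store side (g2s copies into SMEM) and the load side (`ldmatrix` addressing) have to apply the same permutation, otherwise the data simply lands in the wrong place.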

kernels/flash-attn/README.md

Lines changed: 5 additions & 5 deletions
@@ -12,19 +12,19 @@
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|

- This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)` it can run faster than offical FA2/SDPA on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
+ This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)`, it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
- |split-q+share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
+ |share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
- |split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
+ |share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
- |split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
+ |tiling-qkv+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**92 TFLOPS**|**157 TFLOPS**|
|Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

- For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~
+ For example, on an NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On an NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **92 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~

## 📖 Contents
