
Commit f9fa8f0

[FA2] flash-attn-mma tiling-qk for large d⚡️ (#173)
* Update README.md * Update README.md * Update README.md * Update README.md * Rename flash_attn_mma_swizzle_qkv.cu to flash_attn_mma_tiling.cu * Update flash_attn_mma.py * Update flash_attn.cc * Update flash_attn_mma_tiling.cu * Update flash_attn_mma_share_kv.cu * Update flash_attn_mma_share_qkv.cu * Update flash_attn_mma_split_q.cu * Update flash_attn_mma_tiling.cu * Update flash_attn_mma_tiling.cu * Update flash_attn_mma_tiling.cu * Create flash_attn_mma_full_tiling.cu * Update flash_attn_mma_tiling.cu * Update flash_attn_mma.py * Update flash_attn_mma_tiling.cu * Update flash_attn_mma_tiling.cu * Update flash_attn_mma_tiling.cu * fix split-q tiling stage 1 precision error * Update README.md * Update README.md * Update README.md * fix split-q tiling stage 1 precision error * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk * support flash-attn-mma tiling-qk
1 parent c737e03 commit f9fa8f0

9 files changed (+518 / -388 lines changed)

README.md

Lines changed: 25 additions & 17 deletions
@@ -61,21 +61,18 @@ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run fa
 
 ```bash
 python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------------------------------------------------------------------------------------
-B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
-------------------------------------------------------------------------------------------------------------------------
-B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
-mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
-mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
-mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
-mma(split-q+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.981400ms, TFLOPS:35.19
-mma(split-q+share-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.866197ms, TFLOPS:48.89
-mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
-mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
-mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-------------------------------------------------------------------------------------------------------------------------
+-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
+torch(unfused): ['-0.00514603 ', '0.05783081 ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
+mma(split-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
+mma(split-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
+mma(split-q+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
+mma(split-q+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
+mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
+mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
+mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
+mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
+(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
+----------------------------------------------------------------------------------------------------------------------------------
 ```
 The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA warps, is slower than the `Split Q` policy, which splits only Q across MMA warps while keeping K/V accessible to all of them.

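The TFLOPS column in the log above can be roughly reproduced from the reported times. A minimal sketch in Python, assuming the common 4·B·H·N²·D FLOP count for the two attention matmuls (QK^T and P·V); the benchmark script may also count softmax work, so its reported figures can run a couple of percent higher:

```python
# Rough TFLOPS check for the B=1, H=8, N=8192, D=64 log above.
# Assumption: FLOPs ~= 4*B*H*N*N*D (two matmuls, 2 FLOPs per multiply-add); softmax is ignored.
def attn_tflops(B: int, H: int, N: int, D: int, time_ms: float) -> float:
    flops = 4.0 * B * H * N * N * D
    return flops / (time_ms * 1e-3) / 1e12

print(attn_tflops(1, 8, 8192, 64, 3.462291))   # ~39.7, reported 40.47 for mma(split-q+stage1)
print(attn_tflops(1, 8, 8192, 64, 20.999861))  # ~6.5,  reported 6.67  for torch(unfused)
```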
@@ -128,7 +125,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half*
 __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
 ```
+- 📚 Split Q + QK Fine-grained Tiling (**O(16xd) SRAM** vs FA2 **O(4xBrxd) SRAM**, `Headdim -> 1024`)
 
+<div id="mma-tiling-qk"></div>
+
+```C++
+// Fine-grained tiling at the MMA level for Q/K: it keeps the Q/K SRAM size constant at 64*kMmaAtomK,
+// while V needs O(kMmaAtomK*d) SRAM, so the overall SRAM complexity is O(kMmaAtomK*d).
+// This allows extending D (headdim) up to 1024. Performance numbers: stay tuned for updates ~
+__global__ void // Q, K, V, O -> [B, H, N, D]
+flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
+```
 ## ©️Citations🎉🎉
 
 ```BibTeX
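To make the SRAM claim above concrete, here is a back-of-the-envelope sketch. The tile shapes (Br = Bc = 64, kMmaAtomK = 16, fp16 elements) are illustrative assumptions rather than the kernels' exact configuration, and padding, swizzling and multi-stage buffers are not counted:

```python
# Illustrative shared-memory estimates in bytes (fp16 = 2 bytes per element).
# Assumed tile shapes only; the real kernels add padding/stages on top of this.
def tiling_qk_smem(d: int, kMmaAtomK: int = 16, Br: int = 64, Bc: int = 64) -> int:
    q = Br * kMmaAtomK * 2        # Q tile: 64 x kMmaAtomK, independent of d
    k = Bc * kMmaAtomK * 2        # K tile: 64 x kMmaAtomK, independent of d
    v = kMmaAtomK * d * 2         # V tile: kMmaAtomK x d  -> the O(kMmaAtomK*d) term
    return q + k + v

def row_tile_smem(d: int, Br: int = 64, Bc: int = 64) -> int:
    # Staging full Br x d / Bc x d tiles for Q, K and V, as in the shared-kv/shared-qkv style.
    return (Br * d + Bc * d + Bc * d) * 2

for d in (64, 512, 1024):
    print(f"d={d}: tiling-qk ~{tiling_qk_smem(d) // 1024} KiB vs row tiles ~{row_tile_smem(d) // 1024} KiB")
# d=1024: ~36 KiB vs ~384 KiB, which is why the QK tiling path can push headdim
# toward 1024 within the roughly 100 KiB of per-block shared memory on Ampere GPUs.
```

Only the V tile grows with the head dim under QK fine-grained tiling, so the shared-memory budget stays flat for Q/K and linear in d overall.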
@@ -151,8 +158,9 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
 | ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
 | ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
 | ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
 | ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|

kernels/flash-attn/README.md

Lines changed: 52 additions & 15 deletions
@@ -17,21 +17,46 @@ This repository's implementation of FlashAttention is intended solely for learni
 - Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
 ```bash
 python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------------------------------------------------------------------------------------
-B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
-------------------------------------------------------------------------------------------------------------------------
-B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
-mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
-mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
-mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
-mma(split-q+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.981400ms, TFLOPS:35.19
-mma(split-q+share-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.866197ms, TFLOPS:48.89
-mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
-mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
-mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-------------------------------------------------------------------------------------------------------------------------
+-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
+torch(unfused): ['-0.00514603 ', '0.05783081 ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
+mma(split-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
+mma(split-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
+mma(split-q+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
+mma(split-q+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
+mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
+mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
+mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
+mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
+(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
+- Example: B=1, H=48, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+```bash
+python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
+------------------------------------------B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
+torch(unfused): ['-0.00043964 ', '0.03292847 ', '0.01331329 '], time:1708.712411ms, TFLOPS:0.49 (+0.00%)
+mma(split-kv+stage1): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:32.308507ms, TFLOPS:26.02 (+5188.74%)
+mma(split-kv+stage2): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:31.260324ms, TFLOPS:26.89 (+3.35%)
+mma(split-q+stage1): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:23.505139ms, TFLOPS:35.77 (+32.99%)
+mma(split-q+stage2): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:24.225831ms, TFLOPS:34.70
+mma(split-q+share-qkv+stage1): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:17.338157ms, TFLOPS:48.49 (+35.57%)
+mma(split-q+share-qkv+stage2): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:17.652464ms, TFLOPS:47.63
+mma(split-q+share-kv+stage1): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:18.073559ms, TFLOPS:46.52
+mma(split-q+share-kv+stage2): ['-0.00042009 ', '0.03286743 ', '0.01330566 '], time:17.378855ms, TFLOPS:48.38
+(flash): ['-0.00041986 ', '0.03292847 ', '0.01330566 '], time:22.468138ms, TFLOPS:37.42
+----------------------------------------------------------------------------------------------------------------------------------
+```
+- Example: B=1, H=8, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported.
+```bash
+python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
+-----------------------B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1041, Warmup: 1, Iters: 10------------------------
+----------------------------------------------------------------------------------------------------------------------------------
+------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
+mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
+mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
+(sdpa): ['-0.00438309 ', '0.02174377 ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
+----------------------------------------------------------------------------------------------------------------------------------
 ```
 
 ## 📖 Contents
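For context on the D=512 run above, where FA2 cannot be used, PyTorch's `scaled_dot_product_attention` serves as the reference for both correctness and timing. A standalone sketch of such an SDPA baseline (plain PyTorch only, not the repository's flash_attn_mma.py harness; it mirrors the warmup/iters pattern and the [B, H, N, D] layout of the logs):

```python
import torch
import torch.nn.functional as F

# SDPA baseline at B=1, H=8, N=8192, D=512 in half precision on CUDA.
B, H, N, D = 1, 8, 8192, 512
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half) for _ in range(3))

F.scaled_dot_product_attention(q, k, v)   # warmup
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 10
start.record()
for _ in range(iters):
    out = F.scaled_dot_product_attention(q, k, v)
end.record()
torch.cuda.synchronize()

ms = start.elapsed_time(end) / iters
tflops = 4 * B * H * N * N * D / (ms * 1e-3) / 1e12
print(f"(sdpa): time:{ms:.6f}ms, TFLOPS:{tflops:.2f}")
```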
@@ -41,6 +66,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDI
 - [📚 Split Q ](#mma-split-q)
 - [📚 Shared KV SMEM](#mma-share-kv)
 - [📚 Fully Shared QKV SMEM](#mma-share-qkv)
+- [📚 QK Fine-grained Tiling](#mma-tiling-qk)
 - [📖 Prerequisites](#prerequisites)
 - [📖 Installation](#install)
 - [📖 Performance](#perf)
@@ -84,6 +110,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half*
 // and reduce Q SMEM IO-Access.
 __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
+```
+- 📚 Split Q + QK Fine-grained Tiling (**O(16xd) SRAM** vs FA2 **O(4xBrxd) SRAM**, `Headdim -> 1024`)
+
+<div id="mma-tiling-qk"></div>
+
+```C++
+// Fine-grained tiling at the MMA level for Q/K: it keeps the Q/K SRAM size constant at 64*kMmaAtomK,
+// while V needs O(kMmaAtomK*d) SRAM, so the overall SRAM complexity is O(kMmaAtomK*d).
+// This allows extending D (headdim) up to 1024. Performance numbers: stay tuned for updates ~
+__global__ void // Q, K, V, O -> [B, H, N, D]
+flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
 ```
 
 ## 📖 Prerequisites