@@ -17,21 +17,46 @@ This repository's implementation of FlashAttention is intended solely for learning
- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop); the TFLOPS figures are sanity-checked in the sketch after these examples.
``` bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
- ------------------------------------------------------------------------------------------------------------------------
- B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
- ------------------------------------------------------------------------------------------------------------------------
- B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
- torch(unfused): [' -0.0088729 ' , ' -0.00307083 ' , ' 0.00674438 ' ], time:19.318247ms, TFLOPS:7.25
- mma(split-kv+stage1): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:5.330205ms, TFLOPS:26.29
- mma(split-kv+stage2): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:5.058098ms, TFLOPS:27.70
- mma(split-q+stage1): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:3.639126ms, TFLOPS:38.50
- mma(split-q+stage2): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:3.981400ms, TFLOPS:35.19
- mma(split-q+share-kv+stage1): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:2.866197ms, TFLOPS:48.89
- mma(split-q+share-kv+stage2): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:2.584863ms, TFLOPS:54.21
- mma(split-q+share-qkv+stage1): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:2.691698ms, TFLOPS:52.06
- mma(split-q+share-qkv+stage2): [' -0.0089035 ' , ' -0.00307846 ' , ' 0.00675964 ' ], time:2.569842ms, TFLOPS:54.52
- (flash): [' -0.0088653 ' , ' -0.00307836 ' , ' 0.00675201 ' ], time:3.734636ms, TFLOPS:37.52
- ------------------------------------------------------------------------------------------------------------------------
+ -------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
+ torch(unfused): [' -0.00514603 ' , ' 0.05783081 ' , ' -0.00026727 ' ], time:20.999861ms, TFLOPS:6.67 (+0.00%)
+ mma(split-kv+stage1): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:5.120730ms, TFLOPS:27.36 (+310.10%)
+ mma(split-kv+stage2): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:5.004287ms, TFLOPS:28.00 (+2.33%)
+ mma(split-q+stage1): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:3.462291ms, TFLOPS:40.47 (+44.54%)
+ mma(split-q+stage2): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:3.658915ms, TFLOPS:38.30
+ mma(split-q+share-qkv+stage1): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:2.551699ms, TFLOPS:54.91 (+35.69%)
+ mma(split-q+share-qkv+stage2): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:2.532172ms, TFLOPS:55.34 (+0.77%)
+ mma(split-q+share-kv+stage1): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:2.776575ms, TFLOPS:50.46
+ mma(split-q+share-kv+stage2): [' -0.00511169 ' , ' 0.05795288 ' , ' -0.00029612 ' ], time:2.596927ms, TFLOPS:53.96
+ (flash): [' -0.00516129 ' , ' 0.05783081 ' , ' -0.00027728 ' ], time:3.776550ms, TFLOPS:37.10
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ```
+
+ - Example: B=1, H=48, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+ ``` bash
+ python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
+ ------------------------------------------B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
+ torch(unfused): [' -0.00043964 ' , ' 0.03292847 ' , ' 0.01331329 ' ], time:1708.712411ms, TFLOPS:0.49 (+0.00%)
+ mma(split-kv+stage1): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:32.308507ms, TFLOPS:26.02 (+5188.74%)
+ mma(split-kv+stage2): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:31.260324ms, TFLOPS:26.89 (+3.35%)
+ mma(split-q+stage1): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:23.505139ms, TFLOPS:35.77 (+32.99%)
+ mma(split-q+stage2): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:24.225831ms, TFLOPS:34.70
+ mma(split-q+share-qkv+stage1): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:17.338157ms, TFLOPS:48.49 (+35.57%)
+ mma(split-q+share-qkv+stage2): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:17.652464ms, TFLOPS:47.63
+ mma(split-q+share-kv+stage1): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:18.073559ms, TFLOPS:46.52
+ mma(split-q+share-kv+stage2): [' -0.00042009 ' , ' 0.03286743 ' , ' 0.01330566 ' ], time:17.378855ms, TFLOPS:48.38
+ (flash): [' -0.00041986 ' , ' 0.03292847 ' , ' 0.01330566 ' ], time:22.468138ms, TFLOPS:37.42
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ```
+ - Example: B=1, H=8, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported.
+ ``` bash
+ python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
+ -----------------------B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1041, Warmup: 1, Iters: 10------------------------
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
+ mma(split-q+tiling-qk+stage1): [' -0.00433731 ' , ' 0.02165222 ' , ' -0.01544189 ' ], time:48.775554ms, TFLOPS:22.60 (+0.00%)
+ mma(split-q+tiling-qk+stage2): [' -0.00433731 ' , ' 0.02165222 ' , ' -0.01544189 ' ], time:47.503424ms, TFLOPS:23.20 (+2.68%)
+ (sdpa): [' -0.00438309 ' , ' 0.02174377 ' , ' -0.01551056 ' ], time:66.486573ms, TFLOPS:16.58
+ ----------------------------------------------------------------------------------------------------------------------------------
```

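The TFLOPS column above can be roughly reproduced from the standard attention FLOP estimate of about 4\*B\*H\*N^2\*D for the two GEMMs (Q·K^T and P·V); flash_attn_mma.py reports slightly higher figures, most likely because a small softmax term is also counted. A minimal sanity-check sketch of that arithmetic (an estimate, not the script's exact formula):

```C++
// Rough TFLOPS sanity check for the B=1, H=8, N=8192, D=64 example above.
// Assumes FLOPs ~= 4*B*H*N*N*D (the two GEMMs); the softmax adds a few extra
// ops per attention score, so the reported numbers are slightly higher.
#include <cstdio>

int main() {
  const double B = 1, H = 8, N = 8192, D = 64;
  const double flops = 4.0 * B * H * N * N * D;           // ~1.374e11 FLOPs
  const double time_ms = 2.532172;                        // mma(split-q+share-qkv+stage2)
  const double tflops = flops / (time_ms * 1e-3) / 1e12;  // ~54.3 vs. the reported 55.34
  printf("estimated TFLOPS: %.2f\n", tflops);
  return 0;
}
```

The same estimate applied to the D=512 run above (48.775554 ms) gives about 22.5 TFLOPS, close to the reported 22.60.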
## 📖 Contents
@@ -41,6 +66,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
- [📚 Split Q](#mma-split-q)
- [📚 Shared KV SMEM](#mma-share-kv)
- [📚 Fully Shared QKV SMEM](#mma-share-qkv)
+ - [📚 QK Fine-grained Tiling](#mma-tiling-qk)
- [📖 Prerequisites](#prerequisites)
- [📖 Installation](#install)
- [📖 Performance](#perf)
@@ -84,6 +110,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half* O, ...);
// and reduce Q SMEM IO-Access.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
+ ```
+ - 📚 Split Q + QK Fine-grained Tiling (**O(16xd) SRAM** vs FA2 **O(4xBrxd) SRAM**, `Headdim -> 1024`)
+
+ <div id="mma-tiling-qk"></div>
+
+ ```C++
+ // Fine-grained tiling at the MMA level for Q/K keeps the Q/K SRAM footprint constant at 64*kMmaAtomK,
+ // while V requires O(kMmaAtomK*d) SRAM, so the overall SRAM complexity is O(kMmaAtomK*d).
+ // This allows D (headdim) to be extended up to 1024. Performance numbers: stay tuned for updates.
+ __global__ void // Q, K, V, O -> [B, H, N, D]
+ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
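To make the O(kMmaAtomK\*d) figure concrete, the sketch below estimates the per-block shared-memory budget at d = 1024. It assumes kMmaAtomK = 16 (the k dimension of the m16n8k16 HMMA tile), 64 rows per Q/K tile as the 64\*kMmaAtomK comment suggests, half-precision (2-byte) storage, and a single pipeline stage; the kernel's actual tile shapes and stage counts may differ.

```C++
// Estimated shared-memory budget for the split-q + tiling-qk scheme at d = 1024.
// kMmaAtomK = 16 and 64 rows per Q/K tile are assumptions based on the comment
// above, not values read out of the kernel itself.
#include <cstdio>

int main() {
  const int kMmaAtomK = 16, rows = 64, d = 1024;
  const int bytes_q   = rows * kMmaAtomK * 2;         // 2 KiB, independent of d
  const int bytes_k   = rows * kMmaAtomK * 2;         // 2 KiB, independent of d
  const int bytes_v   = kMmaAtomK * d * 2;            // 32 KiB, O(kMmaAtomK*d)
  const int tiling_qk = bytes_q + bytes_k + bytes_v;  // ~36 KiB per block
  const int o4brd     = 4 * rows * d * 2;             // O(4*Br*d) with Br = 64: ~512 KiB
  printf("tiling-qk: ~%d KiB, O(4*Br*d): ~%d KiB\n", tiling_qk / 1024, o4brd / 1024);
  return 0;
}
```

Roughly 36 KiB per block fits comfortably within the ~100 KB of shared memory available per SM on sm_86, whereas an O(4\*Br\*d) layout with Br = 64 would need on the order of 512 KiB at d = 1024, which is consistent with the "FA2 not supported" note in the D=512 example above.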
## 📖 Prerequisites