
Commit c687988

[FA2] flash-attn-mma fully tiling-qkv🎉 (#212)
* Update README.md
* Update README.md
1 parent 14fa9e7 commit c687988


2 files changed (+6, −6 lines)


README.md

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |✔️|✔️|✔️|✔️|
 
 
-I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, **Fully Shared QKV SMEM**, **Prefetch Q s2r**, **Prefetch K/V g2s**, **QK Fine-grained Tiling**, Collective Store, etc. Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, **Fully Shared QKV SMEM**, **Prefetch Q s2r**, **Prefetch K/V g2s**, **QKV Fine-grained Tiling**, Collective Store, etc. Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
 
 ![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)
 
@@ -55,7 +55,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 |**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QKV Fine-grained Tiling**|
 |✔️|✔️|✔️|✔️|
 
-Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)` it can run faster than FA2/SDPA on some Devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QKV Fully Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
+Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192, D <= 64)` it can run faster than FA2/SDPA on some Devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
 
 |Algorithm| (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
 |:---:|:---:|:---:|:---:|:---:|
@@ -131,7 +131,7 @@ __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
 ```
 
-- 📚 Split Q + QKV Fully Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
+- 📚 Split Q + Fully QKV Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
 
 <div id="mma-tiling-qkv"></div>
 
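As a rough aid to reading the tiling-qkv bullet changed above, the following back-of-the-envelope CUDA/C++ host snippet compares the two shared-memory complexity figures quoted in the diff: O(4xBrxd) for the FA2-style layout versus O(Brx16) for fully fine-grained QKV tiling. The concrete Br = 64, half-precision (2-byte) elements, and the restriction to only these tile terms (score/softmax buffers are ignored) are illustrative assumptions, not numbers taken from the repository.

```cuda
// Hypothetical illustration (not from the repo): per-block shared-memory
// estimate for FA2-style full-row tiles vs. fully fine-grained QKV tiling,
// using only the two complexity terms quoted in the README diff above.
#include <cstdio>

int main() {
  const int Br = 64;        // assumed Q-tile rows per block (illustrative)
  const int kMmaTileD = 16; // width of one MMA tile along the head dim d
  const int kElemBytes = 2; // sizeof(half)
  const int dims[] = {64, 128, 256, 512, 1024};

  for (int d : dims) {
    // FA2-style: Q, K, V, O row-tiles resident in SMEM -> O(4 * Br * d)
    size_t fa2_bytes = 4ull * Br * d * kElemBytes;
    // Fully QKV fine-grained tiling: only a Br x 16 slice at a time -> O(Br * 16), constant in d
    size_t tiled_bytes = 1ull * Br * kMmaTileD * kElemBytes;
    printf("d=%4d  FA2-style SMEM: %6zu KB   tiling-qkv SMEM: %4zu KB\n",
           d, fa2_bytes / 1024, tiled_bytes / 1024);
  }
  return 0;
}
```

Under these assumptions the FA2-style footprint grows linearly with d while the fine-grained figure stays flat, which is consistent with the D=512 tiling-qkv results quoted in this commit.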
kernels/flash-attn/README.md

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ This repository's implementation of FlashAttention is intended solely for learni
 |split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
 |Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |
 
-For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QKV Fully Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~
+For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)** that almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~
 
 ## 📖 Contents
 
@@ -34,7 +34,7 @@ For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#
 - [📚 Shared KV SMEM](#mma-share-kv)
 - [📚 Fully Shared QKV SMEM](#mma-share-qkv)
 - [📚 QK Fine-grained Tiling](#mma-tiling-qk)
-- [📚 QKV Fully Fine-grained Tiling](#mma-tiling-qkv)
+- [📚 Fully QKV Fine-grained Tiling](#mma-tiling-qkv)
 - [📖 Prerequisites](#prerequisites)
 - [📖 Installation](#install)
 - [📖 Performance](#perf)
@@ -92,7 +92,7 @@ __global__ void // Q, K, V, O -> [B, H, N, D]
 flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
 ```
 
-- 📚 Split Q + QKV Fully Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
+- 📚 Split Q + Fully QKV Fine-grained Tiling (**O(Brx16)~O(1) SRAM** vs FA2 **O(4xBrxd) SRAM**)
 
 <div id="mma-tiling-qkv"></div>
 
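To make the renamed "Fully QKV Fine-grained Tiling" bullets more concrete, here is a deliberately simplified CUDA sketch of the underlying idea: step the head dimension d in 16-wide slices so shared memory does not grow with d. Everything specific in it is an assumption made for brevity, not the repository's kernel: the Br = Bc = 16 tile shape, a plain FMA loop on one thread per output element in place of MMA PTX, and the omission of softmax, the PV product, multi-stage prefetch, and collective stores.

```cuda
// Hypothetical sketch of fine-grained tiling over d: S = Q * K^T for one
// Br x Bc tile, with Q and K staged through shared memory 16 columns of d
// at a time, so SMEM stays O(Br*16 + Bc*16) regardless of d.
#include <cuda_fp16.h>
#include <cstdio>

constexpr int Br = 16;      // Q-tile rows handled by this block (assumed)
constexpr int Bc = 16;      // K-tile rows handled by this block (assumed)
constexpr int kTileD = 16;  // slice width along the head dimension d

__global__ void qk_fine_grained_tiling_kernel(const half* Q, const half* K,
                                              float* S, int d) {
  __shared__ half q_smem[Br][kTileD];
  __shared__ half k_smem[Bc][kTileD];
  const int row = threadIdx.y;  // 0..Br-1, indexes a Q row
  const int col = threadIdx.x;  // 0..Bc-1, indexes a K row
  float acc = 0.0f;

  // Walk the head dimension 16 columns at a time; only the current slice
  // of Q and K ever lives in shared memory.
  for (int d0 = 0; d0 < d; d0 += kTileD) {
    q_smem[row][col] = Q[row * d + d0 + col];  // works because Br == Bc == kTileD
    k_smem[row][col] = K[row * d + d0 + col];
    __syncthreads();

    for (int t = 0; t < kTileD; ++t) {
      acc += __half2float(q_smem[row][t]) * __half2float(k_smem[col][t]);
    }
    __syncthreads();  // keep the slice alive until every thread has used it
  }
  S[row * Bc + col] = acc;  // one element of the Br x Bc score tile
}

int main() {
  const int d = 512;
  half *Q, *K;
  float *S;
  cudaMallocManaged(&Q, Br * d * sizeof(half));
  cudaMallocManaged(&K, Bc * d * sizeof(half));
  cudaMallocManaged(&S, Br * Bc * sizeof(float));
  for (int i = 0; i < Br * d; ++i) Q[i] = __float2half(0.01f * (i % 7));
  for (int i = 0; i < Bc * d; ++i) K[i] = __float2half(0.02f * (i % 5));

  qk_fine_grained_tiling_kernel<<<1, dim3(Bc, Br)>>>(Q, K, S, d);
  cudaDeviceSynchronize();
  printf("S[0][0] = %.4f\n", S[0]);
  return 0;
}
```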