
Commit d474791

[FA2] split-q + tiling-qk D=512 performance🎉 (#177)
* Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update flash_attn_mma.py * Update README.md * Update README.md * Update README.md
1 parent 697e06f commit d474791

File tree

- README.md
- kernels/flash-attn/README.md
- kernels/flash-attn/flash_attn_mma.py

3 files changed: +64 -14 lines changed

README.md

Lines changed: 29 additions & 6 deletions
@@ -55,8 +55,9 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

-Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
+Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

+- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
@@ -72,6 +73,27 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDI
(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
```
+
+- Example: B=1, H=8, N=8192, `D=512` (NVIDIA RTX 3080 Laptop), FA2 not supported, `QK Tiling` faster than SDPA~🎉🎉
+```bash
+python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
+------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
+mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
+mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
+(sdpa): ['-0.00438309 ', '0.02174377 ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
+- Example: B=1, H=48, N=16384, `D=512` (NVIDIA L20), FA2 not supported, `QK Tiling` faster than SDPA~🎉🎉
+```bash
+python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
+-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
+mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
+mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
+(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA (Warps), is slower than the `Split Q` policy, which splits only Q across MMA (Warps) while every warp accesses the full K/V (see the toy sketch below).

- 📚 Split KV (Basic, FlashAttention-1)
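To make the policy difference concrete, here is a toy, framework-level sketch in plain PyTorch (not the MMA kernels themselves); the number of "warps" and the tensor shapes are illustrative only. `Split Q` gives each warp an independent block of Q rows over the full K/V, whereas `Split KV` gives each warp only a slice of K/V, so the partial results must be merged with online-softmax statistics; that extra reduction is what makes it the slower policy.

```python
# Toy sketch (not the CUDA kernels): contrast the Split Q and Split KV policies
# for a single head. num_warps and the shapes below are illustrative only.
import torch

def attn_ref(q, k, v):
    s = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    return torch.softmax(s, dim=-1) @ v

def split_q(q, k, v, num_warps=4):
    # Split Q: each "warp" owns a block of Q rows and scans the full K/V.
    # The blocks are independent, so no cross-warp merge is needed.
    outs = [attn_ref(q_blk, k, v) for q_blk in q.chunk(num_warps, dim=0)]
    return torch.cat(outs, dim=0)

def split_kv(q, k, v, num_warps=4):
    # Split KV: each "warp" owns a slice of K/V; the partial outputs must be
    # merged with online-softmax statistics (an extra cross-warp reduction).
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0], 1), float("-inf"))   # running row-wise max
    l = torch.zeros(q.shape[0], 1)                   # running softmax denominator
    acc = torch.zeros(q.shape[0], v.shape[-1])       # running unnormalized output
    for k_blk, v_blk in zip(k.chunk(num_warps, dim=0), v.chunk(num_warps, dim=0)):
        s = (q @ k_blk.transpose(-2, -1)) * scale
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        alpha = torch.exp(m - m_new)                 # rescale previous partials
        l = alpha * l + p.sum(dim=-1, keepdim=True)
        acc = alpha * acc + p @ v_blk
        m = m_new
    return acc / l

q, k, v = (torch.randn(128, 64) for _ in range(3))
assert torch.allclose(split_q(q, k, v), attn_ref(q, k, v), atol=1e-5)
assert torch.allclose(split_kv(q, k, v), attn_ref(q, k, v), atol=1e-5)
```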
@@ -128,9 +150,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
<div id="mma-tiling-qk"></div>

```C++
-// Fine-grained tiling (MMA level) for Q/K, it cause constant SRAM size 64*kMmaAtomK for Q/K,
-// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
-// Thus, we can extend D(headdim) to 1024. Performance is stay tuned for updates ~
+// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
+// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
+// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
+// extend D (head dimension) up to 1024. Performance updates are ongoing; stay tuned ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
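As a rough sanity check of the comment above, the sketch below estimates the per-block shared memory footprint under assumed tile parameters (kMmaAtomK = 16, 64-row Q/K tiles, fp16 elements, and a configurable stage count); the real kernel's constants may differ, so treat the numbers as illustrative. Even at d = 1024 with two stages the estimate stays around ~72 KiB, which is consistent with the claim that the head dimension can be extended to 1024 on GPUs with roughly 100 KiB of shared memory per block.

```python
# Back-of-the-envelope SMEM estimate for the split-q + tiling-qk scheme.
# Assumed values (not read from the kernel source): kMmaAtomK = 16,
# 64-row Q/K tiles, fp16 (2 bytes/element), multi-stage pipelining.
BYTES_PER_HALF = 2
K_MMA_ATOM_K = 16      # assumed MMA K-dim tile
ROWS = 64              # assumed Q/K tile rows (the "64" in 64 * kMmaAtomK)

def smem_kib(d: int, stages: int = 2) -> float:
    qk = 2 * ROWS * K_MMA_ATOM_K * BYTES_PER_HALF * stages   # Q + K: constant in d
    v = K_MMA_ATOM_K * d * BYTES_PER_HALF * stages           # V: O(kMmaAtomK * d)
    return (qk + v) / 1024.0

for d in (64, 128, 256, 512, 1024):
    print(f"headdim d={d:4d}: ~{smem_kib(d):5.1f} KiB shared memory per block")
```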
@@ -150,14 +173,14 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half*

<div id="cuda-kernel"></div>

-The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The **Workflow** will look like: custom **CUDA** kernel impl -> **PyTorch** Python bindings -> Run tests. 👉TIPS: `*` = Tensor Cores (WMMA, MMA, CuTe), otherwise, CUDA Cores; `/` = not supported; `✔️` = supported; `` = TODO. Contents:
+The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The **Workflow** for each topic looks like: custom **CUDA** kernel impl -> **PyTorch** Python bindings -> run tests. 👉TIPS: `*` = Tensor Cores (WMMA, MMA, CuTe), otherwise CUDA Cores; `/` = not supported; `✔️` = supported; `` = TODO. Contents are listed below:

- [📚 Easy ⭐️](#cuda-kernel-easy-medium)
- [📚 Medium ⭐️⭐️](#cuda-kernel-easy-medium)
- [📚 Hard ⭐️⭐️⭐️](#cuda-kernel-hard)
- [📚 Hard++ ⭐⭐⭐️⭐️⭐️](#cuda-kernel-hard)

-[📚 Easy](#cuda-kernel-easy-medium) and [📚 Medium](#cuda-kernel-easy-medium) sections cover fundamental operations such as element-wise, mat_trans, warp/block reduce, online-softmax, nms, layer-norm, rms-norm, dot-prod etc. [📚 Hard](#cuda-kernel-hard) and [📚 Hard++](#cuda-kernel-hard) sections delve deeper into advanced topics, primarily focusing on operations like `sgemv, sgemm, hgemv, hgemm and flash-attention`. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX instructions.
+[📚 Easy](#cuda-kernel-easy-medium) and [📚 Medium](#cuda-kernel-easy-medium) sections cover operations such as `element-wise, mat_trans, warp/block reduce, online-softmax, nms, layer-norm, rms-norm, dot-prod, relu, gelu, swish, embedding` and basic usage of `FP32/FP16/BF16/FP8`. [📚 Hard](#cuda-kernel-hard) and [📚 Hard++](#cuda-kernel-hard) sections delve deeper into advanced topics, primarily focusing on operations like `sgemv, sgemm, hgemv, hgemm and flash-attention`. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX.

### 📚 Easy ⭐️ & Medium ⭐️⭐️ ([©️back👆🏻](#cuda-kernel))
<div id="cuda-kernel-easy-medium"></div>

kernels/flash-attn/README.md

Lines changed: 21 additions & 7 deletions
@@ -12,9 +12,12 @@
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

-This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
+This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than official FA2/SDPA on some devices. However, for large-scale attention there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

-- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).
+
+
+- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
@@ -31,7 +34,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDI
----------------------------------------------------------------------------------------------------------------------------------
```

-- Example: B=1, H=48, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+- Example: B=1, H=48, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
@@ -47,7 +50,7 @@ python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVI
(flash): ['-0.00041986 ', '0.03292847 ', '0.01330566 '], time:22.468138ms, TFLOPS:37.42
----------------------------------------------------------------------------------------------------------------------------------
```
-- Example: B=1, H=48, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported.
+- Example: B=1, H=8, N=8192, `D=512` (NVIDIA RTX 3080 Laptop), FA2 not supported, `QK Tiling` faster than SDPA~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
@@ -57,6 +60,16 @@ python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D
----------------------------------------------------------------------------------------------------------------------------------
```

+- Example: B=1, H=48, N=16384, `D=512` (NVIDIA L20), FA2 not supported, `QK Tiling` faster than SDPA~🎉🎉
+```bash
+python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
+-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
+mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
+mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
+(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
## 📖 Contents

- [📖 FlashAttention MMA Kernels](#mma)
@@ -114,9 +127,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
<div id="mma-tiling-qk"></div>

```C++
-// Fine-grained tiling (MMA level) for Q/K, it cause constant SRAM size 64*kMmaAtomK for Q/K,
-// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
-// Thus, we can extend D(headdim) to 1024. Performance is stay tuned for updates ~
+// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
+// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
+// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
+// extend D (head dimension) up to 1024. Performance updates are ongoing; stay tuned ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```

kernels/flash-attn/flash_attn_mma.py

Lines changed: 14 additions & 1 deletion
@@ -2,10 +2,13 @@
import math
import time
import torch
+from torch import Tensor
from torch.nn import functional as F
from torch.utils.cpp_extension import load
from typing import Optional
+from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func
+from functools import partial
import argparse
import random
import numpy as np
@@ -263,6 +266,16 @@ def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    return y


+def sdpa(q: Tensor, k: Tensor, v: Tensor, use_flash: bool = False):
+    if not use_flash:
+        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+            out: Tensor = F.scaled_dot_product_attention(q, k, v)
+    else:
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+            out: Tensor = F.scaled_dot_product_attention(q, k, v)
+    return out
+
+
def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
                    tag: str = "out_mma", check_all: bool = False,
                    is_flash: bool = True):
@@ -330,7 +343,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
if D <= 256:
    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
-    out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
+    out_sdpa, _ = run_benchmark(partial(sdpa, use_flash=(D<=256)), q, k, v, "(sdpa)")
pretty_print_line()

torch.cuda.synchronize()
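For context on the `flash_attn_mma.py` change above, here is a minimal standalone sketch of how the new `sdpa()` helper and the `partial(sdpa, use_flash=(D<=256))` wiring are meant to behave: the benchmark picks the FLASH_ATTENTION backend only when `D <= 256` and otherwise falls back to EFFICIENT_ATTENTION, which is the backend used for the D=512 numbers above. The shapes, dtype, and the condensed one-branch `sdpa()` below are illustrative assumptions, and a CUDA device is assumed.

```python
# Minimal usage sketch for the sdpa() helper added in this commit.
# Assumptions (not from the diff): a CUDA device, fp16 tensors, illustrative shapes.
import torch
from torch.nn import functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from functools import partial

def sdpa(q, k, v, use_flash: bool = False):
    # Same behavior as the helper in the diff, condensed into one branch.
    backend = SDPBackend.FLASH_ATTENTION if use_flash else SDPBackend.EFFICIENT_ATTENTION
    with sdpa_kernel(backend):
        return F.scaled_dot_product_attention(q, k, v)

B, H, N, D = 1, 8, 4096, 512   # D=512: routed to EFFICIENT_ATTENTION
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half) for _ in range(3))

# Mirrors the benchmark wiring: use the flash backend only when D <= 256.
run_sdpa = partial(sdpa, use_flash=(D <= 256))
out = run_sdpa(q, k, v)
print(out.shape)               # torch.Size([1, 8, 4096, 512])
```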
