
Commit 4687e1d

[FA2] flash-attn-mma get rid of transpose-k✔️ (#169)
* Update flash_attn_mma_split_kv.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_share_kv.cu
* Update flash_attn_mma_share_qkv.cu
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent 9324ddf commit 4687e1d

7 files changed: +300 additions, -280 deletions
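
The net effect of this commit, visible in the kernel signatures and benchmark calls in the diffs below: the MMA kernels no longer expect a pre-transposed K in [B, H, D, N] layout; K is now passed in the same [B, H, N, D] layout as Q and V. A minimal before/after sketch in PyTorch (not part of the commit; the direct-call form is only inferred from the `run_benchmark` invocations in flash_attn_mma.py, and the tensor initialization is illustrative):

```python
import torch

# Benchmark shapes from the README example: [B, H, N, D], fp16, contiguous.
B, H, N, D = 1, 8, 8192, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

# Before this commit: callers materialized K^T as an extra [B, H, D, N] copy.
# tk = k.transpose(-2, -1).contiguous()
# lib.flash_attn_mma_stages_split_q(q, tk, v, o, 1)   # stages=1 (inferred call form)

# After this commit: K stays in its natural [B, H, N, D] layout, so the
# transpose and the extra global-memory copy are gone.
# lib.flash_attn_mma_stages_split_q(q, k, v, o, 1)    # stages=1 (inferred call form)
```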

README.md

Lines changed: 15 additions & 15 deletions
@@ -60,12 +60,12 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
 
 ```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
 ------------------------------------------------------------------------------------------------------------------------
 B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
 ------------------------------------------------------------------------------------------------------------------------
 B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
 mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
 mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
 mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -74,8 +74,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
 mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
 mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
 mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-(sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
 ------------------------------------------------------------------------------------------------------------------------
 ```
 The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
@@ -93,7 +92,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
 // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
 __global__ void
 flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
-                                      half* K, // [B, H, D, N] K^T transposed
+                                      half* K, // [B, H, N, D]
                                       half* V, // [B, H, N, D]
                                       half* O, // [B, H, N, D]
                                       int QKV_seqlen);
@@ -113,7 +112,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
 // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
 __global__ void
 flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-                                     half* K, // [B, H, D, N] K^T transposed
+                                     half* K, // [B, H, N, D]
                                      half* V, // [B, H, N, D]
                                      half* O, // [B, H, N, D]
                                      int QKV_seqlen);
@@ -125,23 +124,24 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
 ```C++
 // K, V shared the same shared memory, improve block occupancy.
 __global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
-                                               half* K,
-                                               half* V,
-                                               half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+                                               half* K, // [B, H, N, D]
+                                               half* V, // [B, H, N, D]
+                                               half* O, // [B, H, N, D]
                                                int QKV_seqlen);
 ```
 - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
 
 <div id="mma-share-qkv"></div>
 
 ```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy
+// and reduce Q SMEM IO-Access.
 __global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
-                                                half* K,
-                                                half* V,
-                                                half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+                                                half* K, // [B, H, N, D]
+                                                half* V, // [B, H, N, D]
+                                                half* O, // [B, H, N, D]
                                                 int QKV_seqlen);
 ```
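
For orientation, the transpose that the kernels used to require corresponds to the K^T in the Q·K^T step of attention, and nothing else. The sketch below is a generic unfused formulation over [B, H, N, D] tensors, written independently of the repository's `unfused_standard_attn` helper (it is not the commit's code); it shows that the transpose is a stride-level re-indexing, which the fused kernels can presumably handle internally instead of demanding a pre-transposed copy of K:

```python
import math
import torch

def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Unfused scaled-dot-product attention over [B, H, N, D] tensors."""
    scale = 1.0 / math.sqrt(q.size(-1))
    # k.transpose(-2, -1) is only a view on the last two dims; no data is moved here.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # [B, H, N, N]
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)                            # [B, H, N, D]
```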

kernels/flash-attn/README.md

Lines changed: 15 additions & 15 deletions
@@ -16,12 +16,12 @@ This repository's implementation of FlashAttention is intended solely for learni
 
 - Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
 ```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa # NVIDIA RTX 3080 Laptop
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
 ------------------------------------------------------------------------------------------------------------------------
 B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
 ------------------------------------------------------------------------------------------------------------------------
 B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
 mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
 mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
 mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -30,8 +30,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
 mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
 mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
 mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-(sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
 ------------------------------------------------------------------------------------------------------------------------
 ```
 
@@ -67,7 +66,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
 // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
 __global__ void
 flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
-                                      half* K, // [B, H, D, N] K^T transposed
+                                      half* K, // [B, H, N, D]
                                       half* V, // [B, H, N, D]
                                       half* O, // [B, H, N, D]
                                       int QKV_seqlen);
@@ -87,7 +86,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
 // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
 __global__ void
 flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-                                     half* K, // [B, H, D, N] K^T transposed
+                                     half* K, // [B, H, N, D]
                                      half* V, // [B, H, N, D]
                                      half* O, // [B, H, N, D]
                                      int QKV_seqlen);
@@ -99,23 +98,24 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
 ```C++
 // K, V shared the same shared memory, improve block occupancy.
 __global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
-                                               half* K,
-                                               half* V,
-                                               half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+                                               half* K, // [B, H, N, D]
+                                               half* V, // [B, H, N, D]
+                                               half* O, // [B, H, N, D]
                                                int QKV_seqlen);
 ```
 - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
 
 <div id="mma-share-qkv"></div>
 
 ```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy
+// and reduce Q SMEM IO-Access.
 __global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
-                                                half* K,
-                                                half* V,
-                                                half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+                                                half* K, // [B, H, N, D]
+                                                half* V, // [B, H, N, D]
+                                                half* O, // [B, H, N, D]
                                                 int QKV_seqlen);
 ```

kernels/flash-attn/flash_attn_mma.py

Lines changed: 13 additions & 13 deletions
@@ -176,7 +176,7 @@ def run_benchmark(perf_func: callable,
     out_val = out_val_first[:2]
     out_val.append(out_val_last[-1])
     out_val = [f"{v:<12}" for v in out_val]
-    print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
+    print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
     if show_all:
         print(out)
     time.sleep(args.sleep)
@@ -203,12 +203,12 @@ def get_qkvo(B, H, N, D):
     v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
 
     o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
-    tk = k.transpose(-2, -1).contiguous()
+    # transpose (H,N) -> (N,H) for FA2.
     fq = q.transpose(1, 2).contiguous()
     fk = k.transpose(1, 2).contiguous()
     fv = v.transpose(1, 2).contiguous()
 
-    return q, k, v, o, tk, fq, fk, fv
+    return q, k, v, o, fq, fk, fv
 
 
 # un-fused naive attn
@@ -233,7 +233,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
     print("-" * 120)
     diff = torch.abs(out_flash.float() - out_mma.float())
     all_close = str(torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2))
-    print(f"out_flash vs {tag:<20}, all close: {all_close:<6}, "
+    print(f"out_flash vs {tag:<18}, all close: {all_close:<6}, "
           f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
          f"mean diff: {diff.mean().item():.6f}")
 
@@ -254,19 +254,19 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
 for (B, H, N, D) in BHNDs:
     print("-" * 120)
     print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
-    q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
+    q, k, v, o, fq, fk, fv = get_qkvo(B, H, N, D)
     torch.cuda.synchronize()
 
     if args.run_torch_unfused:
         out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
-    out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
-    out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
-    out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
-    out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
-    out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
-    out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2)
-    out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
-    out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+    out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage1)", o, stages=1)
+    out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage2)", o, stages=2)
+    out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage1)", o, stages=1)
+    out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage2)", o, stages=2)
+    out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
+    out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+    out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage1)", o, stages=1)
+    out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
     out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
     if args.run_torch_sdpa:
         out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
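
Taken together, the script-side change is that `get_qkvo` no longer builds or returns a transposed `tk`, every `lib.flash_attn_mma_*` benchmark call now receives `k` directly, and only the official `flash_attn_func` baseline still needs the [B, N, H, D] layout produced by `transpose(1, 2)`. A condensed sketch of the resulting data preparation (tensor initialization is illustrative, not the script's; the MMA call is left as a comment because its exact form is only inferred from `run_benchmark`):

```python
import torch
from flash_attn import flash_attn_func  # official FA2 baseline used by the script

B, H, N, D = 1, 8, 8192, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

# MMA kernels: q, k, v, o all stay [B, H, N, D]; no tk = k.transpose(-2, -1).contiguous() copy.
# run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v,
#               "mma(split-q+share-qkv+stage2)", o, stages=2)

# flash_attn_func expects [B, N, H, D], hence the (H, N) -> (N, H) transpose kept in get_qkvo.
fq = q.transpose(1, 2).contiguous()
fk = k.transpose(1, 2).contiguous()
fv = v.transpose(1, 2).contiguous()
out_flash = flash_attn_func(fq, fk, fv)  # [B, N, H, D]
```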
