
Commit b1b923a

[FlashAttention] Release flash-attention-mma 0.0.1 🎉 (#158)
* Update makefile * Update .gitignore * Update hgemm_mma_stage.cu * Create flash_attn_mma.py * Delete kernels/flash-attn/flash_attn.py * Update README.md * Update hgemm_mma_stage.cu * Update hgemm_mma_stage_tn_cute.cu * Update README.md * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Create flexiable_flash_attn_mma.cu * Create flash_qattn_mma.cu * Create flexiable_flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flash_attn_mma_fp8.cu * Delete kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn.cc * Update flash_attn_cuda.cu * Update flash_attn_mma_old.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * add more tests * add more tests * add more tests * add more tests * add more tests * add more tests * add more tests * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Create custom_mma_utils.h * Update custom_mma_utils.h * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update custom_mma_utils.h * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.cu * Update custom_mma_utils.h * Update flash_attn_mma.cu * Update flash_attn_mma.py * Delete kernels/flash-attn/mma/custom_mma_utils.h * Delete kernels/flash-attn/mma/flexiable_flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flexiable_flash_attn_mma.cu * Delete kernels/flash-attn/mma/flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flash_attn_mma_old.cu * Delete kernels/flash-attn/mma/flash_attn_mma_bak.cu * Delete kernels/flash-attn/mma/flash_attn_mma.cu * Create flash_attn_mma_naive.cu * Create flash_attn_mma_stage.cu * Create flash_attn_mma_tiling.cu * Update utils.h * Update flash_attn_cuda.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma_stage.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma_stage.cu * Update flash_attn_mma_tiling.cu * Update README.md * Update flash_attn_mma_naive.cu * Update README.md * Update flash_attn_mma.py * Update flash_attn_mma.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update flash_attn_mma_stage.cu * Update README.md * Update README.md * Update flash_attn_mma_stage.cu * Update README.md * Update README.md * Update README.md
1 parent a683145 commit b1b923a

25 files changed: +2724 −357 lines

.github/workflows/.gitignore

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@ __pycache__
 *.bin
 outupt
 bin
+*.log
+*.txt
+*.tex

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@ __pycache__
 *.bin
 outupt
 bin
+*.log
+*.txt
+*.tex

.gitmodules

Lines changed: 1 addition & 3 deletions
@@ -1,6 +1,4 @@
 [submodule "third-party/cutlass"]
 path = third-party/cutlass
 url = https://github.com/NVIDIA/cutlass.git
-tag = v3.5.1
-
-
+tag = v3.5.1

LICENSE

Lines changed: 0 additions & 2 deletions
@@ -672,5 +672,3 @@ may consider it more useful to permit linking proprietary applications with
 the library. If this is what you want to do, use the GNU Lesser General
 Public License instead of this License. But first, please read
 <https://www.gnu.org/licenses/why-not-lgpl.html>.
-
-

README.md

Lines changed: 15 additions & 2 deletions
@@ -42,6 +42,18 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
 |✔️|✔️|✔️|✔️|
 
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, and Collective Store. Performance is continuously being optimized; stay tuned for updates. Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+
+![flash-attn-mma](https://github.com/user-attachments/assets/3e20fdaa-9b31-4dcd-91d5-204905842dce)
+
+|CUDA Cores|Sliced K (Loop over N/D)|Tile Block (Br, Bc, Bd)|MMA (m16n8k16)|
+|:---:|:---:|:---:|:---:|
+|✔️|✔️|✔️|✔️|
+|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
+|✔️|✔️|✔️|✔️|
+|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|Row Major (NN)|
+|✔️|✔️|✔️|✔️|
+
 ## ©️Citations🎉🎉
 
 ```BibTeX
@@ -198,8 +210,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 | ✔️ [hgemv_k32_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
 | ✔️ [hgemv_k128_f16x4](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
 | ✔️ [hgemv_k16_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
-| ✔️ [flash_attn_f32](./kernels/flash-attn/flash_attn.cu)|f32|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️|
-| ✔️ [flash_attn_mma_m16n8k16*](./kernels/flash-attn/flash_attn_mma.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_cuda](./kernels/flash-attn/naive/flash_attn_cuda.cu)|f32|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_naive*](./kernels/flash-attn/mma/flash_attn_mma_naive.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stage*](./kernels/flash-attn/mma/flash_attn_mma_stage.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️|
 | ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️|
 | ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️|
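The README addition above lists Tile Block (Br, Bc, Bd), Multi Stages, and Collective Store as kernel features. The algorithmic core those kernels implement is FlashAttention's tiled online softmax. The snippet below is a minimal PyTorch sketch of that recurrence, not the repository's CUDA kernels: the function name `flash_attn_tiled_ref` and the `Br`/`Bc` values are illustrative, and it is intended only as a readable reference for checking the MMA kernels' outputs.

```python
import math
import torch
import torch.nn.functional as F

def flash_attn_tiled_ref(q, k, v, Br=64, Bc=64):
    """Tiled attention with online softmax, computed in fp32 for stability.
    q, k, v: (B, H, N, D). Br/Bc are illustrative tile sizes, not the kernels' actual config."""
    B, H, N, D = q.shape
    scale = 1.0 / math.sqrt(D)
    qf, kf, vf = q.float(), k.float(), v.float()
    o = torch.zeros_like(qf)
    for i in range(0, N, Br):
        qi = qf[:, :, i:i + Br, :]                          # (B, H, Br, D) query tile
        m_i = qi.new_full(qi.shape[:-1], float("-inf"))     # running row max
        l_i = torch.zeros_like(m_i)                         # running softmax denominator
        acc = torch.zeros_like(qi)                          # unnormalized output accumulator
        for j in range(0, N, Bc):
            kj, vj = kf[:, :, j:j + Bc, :], vf[:, :, j:j + Bc, :]
            s = (qi @ kj.transpose(-2, -1)) * scale         # (B, H, Br, Bc) score tile
            m_new = torch.maximum(m_i, s.amax(dim=-1))      # updated row max
            p = torch.exp(s - m_new[..., None])             # tile-local exp(S - m)
            alpha = torch.exp(m_i - m_new)                  # rescale factor for old state
            l_i = alpha * l_i + p.sum(dim=-1)
            acc = alpha[..., None] * acc + p @ vj
            m_i = m_new
        o[:, :, i:i + Br, :] = acc / l_i[..., None]
    return o.to(q.dtype)

# Quick sanity check against PyTorch's fused reference (non-causal, same 1/sqrt(D) scale):
# q = torch.randn(1, 8, 1024, 64, dtype=torch.half, device="cuda")
# k, v = torch.randn_like(q), torch.randn_like(q)
# ref = F.scaled_dot_product_attention(q, k, v)
# print(torch.allclose(flash_attn_tiled_ref(q, k, v), ref, atol=1e-2))
```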

kernels/flash-attn/README.md

Lines changed: 349 additions & 108 deletions
Large diffs are not rendered by default.

kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu

Whitespace-only changes.

kernels/flash-attn/flash_attn.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

kernels/flash-attn/flash_attn_mma.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
import os
import math
import time
import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
from typing import Optional
from flash_attn import flash_attn_func
import argparse
import random
import numpy as np

torch.set_grad_enabled(False)
torch.set_printoptions(precision=6, threshold=8, edgeitems=3,
                       linewidth=120, sci_mode=False)


def set_rand_seed(seed: int = 1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_project_dir():
    return os.path.dirname(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))


project_dir = get_project_dir()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-rand-q", '--no-rq', action="store_true")
    parser.add_argument("--no-rand-k", '--no-rk', action="store_true")
    parser.add_argument("--no-rand-v", '--no-rv', action="store_true")
    parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true")
    parser.add_argument("--naive", action="store_true")
    parser.add_argument("--sdpa", action="store_true")
    parser.add_argument("--check", action="store_true")
    parser.add_argument("--show-all", '--show', action="store_true")
    parser.add_argument("--B", type=int, default=None)
    parser.add_argument("--H", type=int, default=None)
    parser.add_argument("--N", type=int, default=None)
    parser.add_argument("--D", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument("--range-k", '--gk', action="store_true")
    return parser.parse_args()


args = get_args()
print(args)


# Load the CUDA kernels as a Python module (JIT-compiled on first use).
lib = load(name='flash_attn_lib',
           sources=[
               './naive/flash_attn_cuda.cu',
               './mma/flash_attn_mma_naive.cu',
               './mma/flash_attn_mma_stage.cu',
               './pybind/flash_attn.cc'],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math",
               f"-I {project_dir}/kernels/flash-attn/utils",
               "-DFLASH_ATTN_MMA_DEBUG" if args.debug else ""
           ],
           extra_cflags=['-std=c++17'])


def run_benchmark(perf_func: callable,
                  q: torch.Tensor,
                  k: torch.Tensor,
                  v: torch.Tensor,
                  tag: str,
                  out: Optional[torch.Tensor] = None,
                  s: Optional[torch.Tensor] = None,  # DEBUG
                  stages: int = -1,
                  warmup: int = args.warmup,
                  iters: int = args.iters,
                  show_all: bool = args.show_all):
    if out is not None:
        out.fill_(0)
    if s is not None:
        s.fill_(0)
    if out is not None:
        for i in range(warmup):
            if stages >= 1:
                if s is not None:
                    perf_func(q, k, v, out, s, stages)
                else:
                    perf_func(q, k, v, out, stages)
            else:
                perf_func(q, k, v, out)
    else:
        for i in range(warmup):
            _ = perf_func(q, k, v)

    torch.cuda.synchronize()
    start = time.time()
    # timed iterations
    if out is not None:
        for i in range(iters):
            if stages >= 1:
                if s is not None:
                    perf_func(q, k, v, out, s, stages)
                else:
                    perf_func(q, k, v, out, stages)
            else:
                perf_func(q, k, v, out)
    else:
        for i in range(iters):
            out = perf_func(q, k, v)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"{tag}"
    out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
    out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
    out_val_first = [round(v, 8) for v in out_val_first]
    out_val_last = [round(v, 8) for v in out_val_last]
    out_val = out_val_first[:2]
    out_val.append(out_val_last[-1])
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    time.sleep(0.05)
    return out.clone(), mean_time


def get_qkvo(B, H, N, D):
    if not (args.no_rand_q or args.no_rand_qkv):
        q = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        q = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    if not (args.no_rand_k or args.no_rand_qkv):
        k = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        k = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    if args.range_k:
        for i in range(N):
            k[:, :, i, :] = (i + 1) / N
        k = k.cuda().half().contiguous()
    if not (args.no_rand_v or args.no_rand_qkv):
        v = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

    o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

    return q, k, v, o


# un-fused naive attention
def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y


Bs = [1, 2, 4] if not args.B else [args.B]
Hs = [1, 4, 8] if not args.H else [args.H]
Ns = [1024, 2048] if not args.N else [args.N]
Ds = [64, 128] if not args.D else [args.D]
# batch_size, n_head, seq_len, head_dim (B,H,N,D)
BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds]

seed = args.seed if args.seed else random.choice(range(10000))
set_rand_seed(seed)
print("-" * 100)
print(" " * 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
      f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}")

for (B, H, N, D) in BHNDs:
    print("-" * 100)
    print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
    q, k, v, o = get_qkvo(B, H, N, D)
    tk = k.transpose(-2, -1).contiguous()   # (B, H, D, N) for the stage kernels
    fq = q.transpose(1, 2).contiguous()     # (B, N, H, D) for flash_attn_func
    fk = k.transpose(1, 2).contiguous()
    fv = v.transpose(1, 2).contiguous()
    torch.cuda.synchronize()

    if args.naive:
        out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)")

    # using fp16 Tensor Core MMA instructions
    out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o)
    out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1)
    out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2)
    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")

    if args.sdpa:
        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
    print("-" * 100)

    torch.cuda.synchronize()
    if args.check:
        out_flash = out_flash.transpose(1, 2)
        for i in range(int(N / 8)):
            if i < 4:
                print("-" * 100)
                print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
                print(out_flash[:, :, (i*8):(i+1)*8, :].float())
                print(f"out_mma_stage1[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
                print(out_mma_stage1[:, :, (i*8):(i+1)*8, :].float())
        print("-" * 100)
        print(f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}")