Commit 5afd8c1

[FA2] Release flash-attn-mma split-kv/q🎉 (#160)
* Update and rename flash_attn_mma_tiling.cu to flexiable_flash_attn_mma.cu * Update flexiable_flash_attn_mma.cu * Update flash_attn.cc * Update flash_attn_mma.py * Update flexiable_flash_attn_mma.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Rename flexiable_flash_attn_mma.cu to flexiable_flash_attn_mma_split_kv.cu * Create flexiable_flash_attn_mma_split_q.cu * Update flexiable_flash_attn_mma_split_kv.cu * Update flexiable_flash_attn_mma_split_q.cu * Update flash_attn.cc * Update flash_attn_mma.py * Update flexiable_flash_attn_mma_split_kv.cu * Update flexiable_flash_attn_mma_split_q.cu * Update flash_attn_mma_stage.cu * Update flexiable_flash_attn_mma_split_kv.cu * Update flexiable_flash_attn_mma_split_q.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Update README.md * Update README.md * Update flexiable_flash_attn_mma_split_kv.cu * Update flash_attn_mma_naive.cu * Update utils.h * Update flash_attn_mma.py * Update flexiable_flash_attn_mma_split_q.cu * support flash-attn-mma split-q * support flash-attn-mma split-q * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update .gitmodules * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * support flash-attn-mma split-q * support flash-attn-mma split-q * Update README.md * support flash-attn-mma split-q * support flash-attn-mma split-q * support flash-attn-mma split-q * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * support flash-attn-mma split-q * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md
1 parent 81404c1 commit 5afd8c1

File tree

19 files changed, +1513 additions, -1558 deletions


.github/workflows/.gitignore

Lines changed: 2 additions & 1 deletion
@@ -21,4 +21,5 @@ outupt
 bin
 *.log
 *.txt
-*.tex
+*.tex
+__pycache__

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -21,4 +21,5 @@ outupt
 bin
 *.log
 *.txt
-*.tex
+*.tex
+__pycache__

.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 [submodule "third-party/cutlass"]
 	path = third-party/cutlass
 	url = https://github.com/NVIDIA/cutlass.git
-	tag = v3.5.1
+	tag = v3.5.1

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
 may consider it more useful to permit linking proprietary applications with
 the library. If this is what you want to do, use the GNU Lesser General
 Public License instead of this License. But first, please read
-<https://www.gnu.org/licenses/why-not-lgpl.html>.
+<https://www.gnu.org/licenses/why-not-lgpl.html>.

README.md

Lines changed: 114 additions & 45 deletions
Large diffs are not rendered by default.

kernels/flash-attn/README.md

Lines changed: 144 additions & 376 deletions
Large diffs are not rendered by default.

kernels/flash-attn/cutlass/flash_attn_cute.cu

Whitespace-only changes.

kernels/flash-attn/flash_attn_mma.py

Lines changed: 88 additions & 41 deletions
@@ -36,8 +36,8 @@ def get_args():
     parser.add_argument("--no-rand-k", '--no-rk', action="store_true")
     parser.add_argument("--no-rand-v", '--no-rv', action="store_true")
     parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true")
-    parser.add_argument("--naive", action="store_true")
-    parser.add_argument("--sdpa", action="store_true")
+    parser.add_argument("--run-torch-unfused", '--torch', action="store_true")
+    parser.add_argument("--run-torch-sdpa", '--sdpa', action="store_true")
     parser.add_argument("--check", action="store_true")
     parser.add_argument("--show-all", '--show', action="store_true")
     parser.add_argument("--B", type=int, default=None)
@@ -46,6 +46,7 @@ def get_args():
     parser.add_argument("--D", type=int, default=None)
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--verbose", '--v', action="store_true")
     parser.add_argument("--warmup", type=int, default=2)
     parser.add_argument("--iters", type=int, default=10)
     parser.add_argument("--range-k", '--gk', action="store_true")
@@ -59,10 +60,10 @@ def get_args():
 # Load the CUDA kernel as a python module
 lib = load(name='flash_attn_lib',
            sources=[
-               './naive/flash_attn_cuda.cu',
-               './mma/flash_attn_mma_naive.cu',
-               './mma/flash_attn_mma_stage.cu',
-               './pybind/flash_attn.cc'],
+               './mma/flash_attn_mma_split_kv.cu',
+               './mma/flash_attn_mma_split_q.cu',
+               './pybind/flash_attn.cc'
+           ],
            extra_cuda_cflags=[
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
@@ -72,10 +73,43 @@ def get_args():
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
+               "-Xptxas -v",
+               "-diag-suppress 177",
                f"-I {project_dir}/kernels/flash-attn/utils",
                "-DFLASH_ATTN_MMA_DEBUG" if args.debug else ""
            ],
-           extra_cflags=['-std=c++17'])
+           extra_cflags=['-std=c++17'],
+           verbose=args.verbose)
+
+
+def get_mha_tflops(B, H, N, D, T=1.0):
+    # Q @ K^T FLOPs
+    flops_qk = B * H * N * N * (2 * D - 1)
+
+    # Scaling FLOPs
+    flops_scaling = B * H * N * N
+
+    # Safe_Softmax FLOPs
+    flops_row_max = B * H * N * (N - 1)   # row max
+    flops_subtract_max = B * H * N * N    # sub max
+    flops_exp = B * H * N * N             # pointwise exp
+    flops_row_sum = B * H * N * (N - 1)   # row sum
+    flops_normalization = B * H * N * N   # normalization
+
+    flops_safe_softmax = flops_row_max + flops_subtract_max + flops_exp + flops_row_sum + flops_normalization
+
+    # P @ V FLOPs
+    flops_pv = B * H * N * D * (2 * N - 1)
+
+    # Total FLOPs
+    total_flops = flops_qk + flops_scaling + flops_safe_softmax + flops_pv
+
+    # Convert to TFLOPS
+    # 1 TFLOPS = 10^12 FLOPS
+    # ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
+    tflops = total_flops * 1e-12 / (T)
+
+    return tflops
 
 
 def run_benchmark(perf_func: callable,
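
For reference, the FLOPs model introduced by get_mha_tflops can be reproduced standalone; the sketch below uses an assumed shape and latency (B=1, H=8, N=4096, D=64 and 1 ms per iteration are illustrative values, not measurements from this commit):

# Standalone sketch of the FLOPs model used by get_mha_tflops.
# B/H/N/D and mean_secs below are assumptions for illustration only.
B, H, N, D = 1, 8, 4096, 64                        # batch, heads, seq_len, head_dim
flops_qk = B * H * N * N * (2 * D - 1)             # Q @ K^T
flops_scaling = B * H * N * N                      # * 1/sqrt(D)
flops_softmax = (B * H * N * (N - 1)               # row max
                 + B * H * N * N                   # subtract max
                 + B * H * N * N                   # pointwise exp
                 + B * H * N * (N - 1)             # row sum
                 + B * H * N * N)                  # normalization
flops_pv = B * H * N * D * (2 * N - 1)             # P @ V
total_flops = flops_qk + flops_scaling + flops_softmax + flops_pv
mean_secs = 1e-3                                   # assumed mean time per iteration
print(f"TFLOPS: {total_flops * 1e-12 / mean_secs:.2f}")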
@@ -123,8 +157,14 @@ def run_benchmark(perf_func: callable,
     out = perf_func(q, k, v)
     torch.cuda.synchronize()
     end = time.time()
+    total_secs = (end - start)
     total_time = (end - start) * 1000  # ms
     mean_time = total_time / iters
+    mean_secs = total_secs / iters
+    B, H, N, D = q.size()
+    if "flash" in tag:
+        B, N, H, D = q.size()
+    TFLOPS = get_mha_tflops(B, H, N, D, mean_secs)
     out_info = f"{tag}"
     out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
     out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
@@ -133,10 +173,11 @@ def run_benchmark(perf_func: callable,
     out_val = out_val_first[:2]
     out_val.append(out_val_last[-1])
     out_val = [f"{v:<12}" for v in out_val]
-    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
+    print(f"{out_info:>25}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
     if show_all:
         print(out)
     time.sleep(0.05)
+    torch.cuda.synchronize()
     return out.clone(), mean_time
 
 
@@ -159,18 +200,38 @@ def get_qkvo(B, H, N, D):
     v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
 
     o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
+    tk = k.transpose(-2, -1).contiguous()
+    fq = q.transpose(1, 2).contiguous()
+    fk = k.transpose(1, 2).contiguous()
+    fv = v.transpose(1, 2).contiguous()
 
-    return q, k, v, o
+    return q, k, v, o, tk, fq, fk, fv
 
 
 # un-fused naive attn
-def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
     att = F.softmax(att, dim=-1)
     y = att @ v
     return y
 
 
+def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
+                    tag: str = "out_mma", show_all: bool = False):
+    out_flash = out_flash.transpose(1, 2)
+    if show_all:
+        for i in range(int(N/8)):
+            if i < 4:
+                print("-" * 120)
+                print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
+                print(out_flash[:, :, (i*8):(i+1)*8, :].float())
+                print(f"{tag}[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
+                print(out_mma[:, :, (i*8):(i+1)*8, :].float())
+                print("-" * 120)
+    all_close = torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2)
+    print(f"out_flash vs {tag}: {all_close}")
+
+
 Bs = [1, 2, 4] if not args.B else [args.B]
 Hs = [1, 4, 8] if not args.H else [args.H]
 Ns = [1024, 2048] if not args.N else [args.N]
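
The new check_all_close helper transposes the flash-attn reference output from (B, N, H, D) back to (B, H, N, D) and compares against the MMA kernel output with torch.allclose at atol=1e-2. A minimal CPU-only sketch of the same kind of check, using the un-fused reference against torch's built-in SDPA instead of the custom kernels (shapes and dtype are illustrative; the script itself uses fp16 CUDA tensors):

import math
import torch
import torch.nn.functional as F

# Illustrative shapes; the benchmark sweeps B/H/N/D and runs on CUDA in half precision.
B, H, N, D = 1, 2, 64, 32
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Un-fused reference, same math as unfused_standard_attn in the script.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
y_ref = F.softmax(att, dim=-1) @ v

# Fused baseline for comparison.
y_sdpa = F.scaled_dot_product_attention(q, k, v)

# Same tolerance the commit uses for its kernel checks.
print(torch.allclose(y_ref, y_sdpa, atol=1e-2))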
@@ -180,42 +241,28 @@ def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 
 seed = args.seed if args.seed else random.choice(range(10000))
 set_rand_seed(seed)
-print("-" * 100)
-print(" "* 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
+print("-" * 120)
+print(" "* 20 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
       f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}")
 
 for (B, H, N, D) in BHNDs:
-    print("-" * 100)
-    print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
-    q, k, v, o = get_qkvo(B, H, N, D)
-    tk = k.transpose(-2, -1).contiguous()
-    fq = q.transpose(1, 2).contiguous()
-    fk = k.transpose(1, 2).contiguous()
-    fv = v.transpose(1, 2).contiguous()
+    print("-" * 120)
+    print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
+    q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
     torch.cuda.synchronize()
 
-    if args.naive:
-        out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)")
-
-    # using fp16 Tesor Core MMA instruction
-    out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o)
-    out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1)
-    out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2)
-    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
-
-    if args.sdpa:
-        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
-    print("-" * 100)
+    if args.run_torch_unfused:
+        out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
+    out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
+    out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
+    out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
+    out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
+    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
+    if args.run_torch_sdpa:
+        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
+    print("-" * 120)
 
     torch.cuda.synchronize()
     if args.check:
-        out_flash = out_flash.transpose(1, 2)
-        for i in range(int(N/8)):
-            if i < 4:
-                print("-" * 100)
-                print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
-                print(out_flash[:, :, (i*8):(i+1)*8, :].float())
-                print(f"out_mma_stage1[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
-                print(out_mma_stage1[:, :, (i*8):(i+1)*8, :].float())
-                print("-" * 100)
-        print(f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}")
+        check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all)
+        check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all)
