
Commit d8abbaa

[FlashAttention] Refactor flash_attn_1_fwd_f32 kernel (#33)
* [FA] Refactor flash_attn_1_fwd_f32 kernel
* [FA] Refactor flash_attn_1_fwd_f32 kernel
1 parent dfabac3 commit d8abbaa

File tree

8 files changed, +497 -354 lines changed


README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -87,9 +87,9 @@
 | ✔️ [hgemv_k16_f16_kernel](./hgemv)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
 | ✔️ [flash_attn_1_fwd_f32_kernel](./flash-attn/flash_attn_1_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
 |[flash_attn_2_fwd_f32_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
-|[flash_attn_2_fwd_f16_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|f16|f32|[link](./flash-attn)|⭐️⭐️⭐️|
-| [flash_attn_2_fwd_bf16_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|bf16|f32|[link](./flash-attn)|⭐️⭐️⭐️|
-| ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/||⭐️|
+|[flash_attn_2_fwd_f16_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_2_fwd_f16_mma_m16n8k16](./flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
+| ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️|
 | ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️|
 
 ## 0x01 📖 Blog Index
```

flash-attn/README.md

Lines changed: 15 additions & 52 deletions
````diff
@@ -1,61 +1,24 @@
-## FlashAttention Tests
+# FlashAttention
 
-### Prerequisites
-- PyTorch >= 2.2.1
-- CUDA >= 12.2
+## 0x00 Overview
 
-```bash
-python3 -m pip install torch
-```
+Includes the following:
+
+- [X] flash_attn_1_fwd_f32_kernel
+- [ ] flash_attn_2_fwd_f32_kernel
+- [ ] flash_attn_2_fwd_f16_kernel
+- [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel
+- [X] PyTorch bindings
 
 ### Run the test
 ```bash
 python3 flash_attn.py
 ```
-Sample log (RTX 3080 Ti):
+Sample log:
 ```bash
-python3 flash_attn.py
-=== profiling manual attention ===
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:320] Completed Stage: Collection
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls  Total KFLOPs
-------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-manual_attn  45.60%  513.000us  98.31%  1.106ms  1.106ms  489.000us  43.82%  1.116ms  1.116ms  1  --
-aten::matmul  14.31%  161.000us  40.27%  453.000us  226.500us  131.000us  11.74%  496.000us  248.000us  2  --
-aten::bmm  5.78%  65.000us  7.82%  88.000us  44.000us  166.000us  14.87%  166.000us  83.000us  2  201326.592
-aten::reshape  4.98%  56.000us  7.38%  83.000us  20.750us  74.000us  6.63%  105.000us  26.250us  4  --
-aten::expand  4.62%  52.000us  6.13%  69.000us  17.250us  65.000us  5.82%  90.000us  22.500us  4  --
-aten::transpose  3.47%  39.000us  4.27%  48.000us  48.000us  44.000us  3.94%  54.000us  54.000us  1  --
-aten::softmax  1.16%  13.000us  3.47%  39.000us  39.000us  17.000us  1.52%  44.000us  44.000us  1  --
-aten::as_strided  0.53%  6.000us  0.53%  6.000us  1.200us  35.000us  3.14%  35.000us  7.000us  5  --
-aten::mul  2.22%  25.000us  2.84%  32.000us  32.000us  33.000us  2.96%  33.000us  33.000us  1  786.432
-aten::_softmax  1.42%  16.000us  1.96%  22.000us  22.000us  27.000us  2.42%  27.000us  27.000us  1  --
-------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-Self CPU time total: 1.125ms
-Self CUDA time total: 1.116ms
-
-=== profiling flash_attn_1_fwd_f32 attention ===
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:320] Completed Stage: Collection
-STAGE:2024-03-25 08:47:18 3818250:3818250 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-flash_attn_1_fwd_f32  5.76%  148.000us  15.72%  404.000us  404.000us  1.804ms  96.37%  1.872ms  1.872ms  1
-aten::zeros_like  1.21%  31.000us  5.21%  134.000us  134.000us  8.000us  0.43%  31.000us  31.000us  1
-aten::zero_  1.44%  37.000us  2.96%  76.000us  38.000us  11.000us  0.59%  25.000us  12.500us  2
-aten::zeros  0.78%  20.000us  2.41%  62.000us  62.000us  8.000us  0.43%  21.000us  21.000us  1
-aten::fill_  0.89%  23.000us  1.60%  41.000us  13.667us  19.000us  1.01%  19.000us  6.333us  3
-aten::full  0.74%  19.000us  1.71%  44.000us  44.000us  9.000us  0.48%  16.000us  16.000us  1
-aten::empty_like  1.01%  26.000us  1.71%  44.000us  44.000us  6.000us  0.32%  8.000us  8.000us  1
-aten::empty  0.62%  16.000us  0.62%  16.000us  8.000us  5.000us  0.27%  5.000us  2.500us  2
-aten::empty_strided  0.54%  14.000us  0.54%  14.000us  14.000us  2.000us  0.11%  2.000us  2.000us  1
-cudaEventRecord  2.18%  56.000us  2.18%  56.000us  2.154us  0.000us  0.00%  0.000us  0.000us  26
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
-Self CPU time total: 2.570ms
-Self CUDA time total: 1.872ms
-
-attn values sanity check: True
+--------------------------------------------------------------------------------
+    out_fa1fwdf32: [0.11064263, 0.08648866, -0.07250906], time:2.32403278ms
+out_fa1fwdf32(v2): [0.11064263, 0.08648866, -0.07250906], time:2.22899675ms
+   out_attnf32_th: [0.11064263, 0.08648865, -0.07250906], time:0.11474848ms
+--------------------------------------------------------------------------------
 ```
````
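
For context on what the refactored kernel computes: flash_attn_1_fwd_f32_kernel implements the FlashAttention-1 forward pass, a block-wise "online softmax" over K/V tiles that never materializes the full N x N attention matrix. The sketch below is a minimal PyTorch reference of that recurrence, included here only as a reading aid; it is not part of this commit, and the function name `flash_attn_1_fwd_reference` and tile size `Bc` are illustrative choices.

```python
import math
import torch

def flash_attn_1_fwd_reference(q, k, v, Bc: int = 32):
    """Block-wise online-softmax attention; q, k, v are (B, nh, N, d) f32 tensors."""
    B, nh, N, d = q.shape
    scale = 1.0 / math.sqrt(d)
    o = torch.zeros_like(q)                                        # running (already normalized) output
    l = torch.zeros(B, nh, N, 1, device=q.device)                  # running softmax denominator
    m = torch.full((B, nh, N, 1), float("-inf"), device=q.device)  # running row max
    for j in range(0, N, Bc):
        kj, vj = k[:, :, j:j + Bc, :], v[:, :, j:j + Bc, :]
        s = q @ kj.transpose(-2, -1) * scale                       # scores against this K tile
        m_ij = s.max(dim=-1, keepdim=True).values
        p = torch.exp(s - m_ij)                                    # tile-local softmax numerator
        l_ij = p.sum(dim=-1, keepdim=True)
        m_new = torch.maximum(m, m_ij)
        alpha, beta = torch.exp(m - m_new), torch.exp(m_ij - m_new)
        l_new = alpha * l + beta * l_ij
        o = (alpha * l * o + beta * (p @ vj)) / l_new              # rescale old output, fold in new tile
        m, l = m_new, l_new
    return o
```

On f32 inputs this should agree with `manual_attn` from flash_attn.py up to rounding error; the CUDA kernel performs the same per-row arithmetic while staging the Q/K/V tiles in shared memory.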

flash-attn/flash_attn.cc

Lines changed: 0 additions & 9 deletions
This file was deleted.

flash-attn/flash_attn.py

Lines changed: 64 additions & 49 deletions
```diff
@@ -4,67 +4,82 @@
 import torch
 from torch.nn import functional as F
 from torch.utils.cpp_extension import load
+from functools import partial
+from typing import Optional
 
 torch.set_grad_enabled(False)
 # Load the CUDA kernel as a python module
-custom_flash_attn = load(name='custom_flash_attn',
-                         sources=[
-                             'flash_attn.cc',
-                             'flash_attn_1_fwd_f32.cu',
-                             'flash_attn_2_fwd_f32.cu'
-                         ],
-                         extra_cuda_cflags=['-O2'])
+lib = load(name='flash_attn_lib',
+           sources=['flash_attn_1_fwd_f32.cu'],
+           extra_cuda_cflags=[
+               "-O3",
+               "-U__CUDA_NO_HALF_OPERATORS__",
+               "-U__CUDA_NO_HALF_CONVERSIONS__",
+               "-U__CUDA_NO_HALF2_OPERATORS__",
+               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+               "--expt-relaxed-constexpr",
+               "--expt-extended-lambda",
+               "--use_fast_math"
+           ],
+           extra_cflags=['-std=c++17'])
 
 # Use small model params, otherwise slower than manual attention. See caveats in README.
-batch_size = 16
-n_head = 12
-seq_len = 64
-head_embd = 64
 
-q = torch.randn(batch_size, n_head, seq_len, head_embd).float().cuda()
-k = torch.randn(batch_size, n_head, seq_len, head_embd).float().cuda()
-v = torch.randn(batch_size, n_head, seq_len, head_embd).float().cuda()
-q.requires_grad = False
-k.requires_grad = False
-v.requires_grad = False
-print('=== profiling manual attention ===')
-
-def manual_attn(q, k, v):
+# un-fused naive attn
+def manual_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
     att = F.softmax(att, dim=-1)
     y = att @ v
     return y
 
-for _ in range(2):
-    manual_result = manual_attn(q, k, v) # warmup
-
-torch.cuda.synchronize()
-with torch.autograd.profiler.profile(use_cuda=True, with_flops=True) as prof:
-    with torch.autograd.profiler.record_function("manual_attn"):
-        manual_result = manual_attn(q, k, v)
-print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
 
-for _ in range(2):
-    custom_result = custom_flash_attn.flash_attn_1_fwd_f32(q, k, v) # warmup
-print('=== profiling flash_attn_1_fwd_f32 attention === ')
-with torch.autograd.profiler.profile(use_cuda=True, with_flops=True) as prof:
-    with torch.autograd.profiler.record_function("flash_attn_1_fwd_f32"):
-        custom_result = custom_flash_attn.flash_attn_1_fwd_f32(q, k, v)
-print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
-print('attn values sanity check:', torch.allclose(custom_result, manual_result, rtol=0, atol=1e-02))
-
-# Why custom flash attn is slow than naive attn in for loop test ?
-REPEAT = 10
-manual_result = manual_attn(q, k, v) # warmup
-st = time.time()
-for _ in range(REPEAT):
-    manual_result = manual_attn(q, k, v)
+def run_benchmark(perf_func: callable,
+                  q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                  tag: str, out: Optional[torch.Tensor] = None,
+                  warmup: int = 10, iters: int = 200,
+                  show_all: bool = False):
+    if out is not None:
+        out.fill_(0)
+    if out is not None:
+        for i in range(warmup):
+            perf_func(q, k, v, out)
+    else:
+        for i in range(warmup):
+            _ = perf_func(q, k, v)
+
     torch.cuda.synchronize()
-print(f"manual attention mean time(ms): {((time.time() - st) * 1000) / REPEAT}")
-custom_result = custom_flash_attn.flash_attn_1_fwd_f32(q, k, v) # warmup
-st = time.time()
-for _ in range(REPEAT):
-    custom_result = custom_flash_attn.flash_attn_1_fwd_f32(q, k, v)
+    start = time.time()
+    # iters
+    if out is not None:
+        for i in range(iters):
+            perf_func(q, k, v, out)
+    else:
+        for i in range(iters):
+            out = perf_func(q, k, v)
     torch.cuda.synchronize()
-print(f"flash_attn_1_fwd_f32 mean time(ms): {((time.time() - st) * 1000) / REPEAT}")
+    end = time.time()
+    total_time = (end - start) * 1000 # ms
+    mean_time = total_time / iters
+    out_info = f"out_{tag}"
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
+    out_val = [round(v, 8) for v in out_val]
+    print(f"{out_info:>17}: {out_val}, time:{mean_time:.8f}ms")
+    if show_all: print(out[0, 0, 0, :])
+    return out.clone(), mean_time
+
 
+print("-" * 80)
+# batch_size, n_head, seq_len, head_dim (B,nh,N,d)
+B, nh, N, d = 16, 12, 64, 64
+q = torch.randn(B, nh, N, d).float().cuda().contiguous()
+k = torch.randn(B, nh, N, d).float().cuda().contiguous()
+v = torch.randn(B, nh, N, d).float().cuda().contiguous()
+o = torch.randn(B, nh, N, d).float().cuda().contiguous()
+q.requires_grad = False
+k.requires_grad = False
+v.requires_grad = False
+o.requires_grad = False
+run_benchmark(lib.flash_attn_1_fwd_f32, q, k, v, "fa1fwdf32")
+run_benchmark(lib.flash_attn_1_fwd_f32_v2, q, k, v, "fa1fwdf32(v2)", o)
+run_benchmark(manual_attn, q, k, v, "attnf32_th")
+print("-" * 80)
```
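
One thing the refactored script no longer does is the explicit accuracy check the old script ran via `torch.allclose`. If you want to restore it on top of `run_benchmark`, a minimal sketch could look like the following (the names `out_fa1` and `out_ref` are illustrative, not from this commit):

```python
# Re-add the old sanity check using run_benchmark's return value: (output tensor, mean time in ms).
out_fa1, _ = run_benchmark(lib.flash_attn_1_fwd_f32, q, k, v, "fa1fwdf32")
out_ref, _ = run_benchmark(manual_attn, q, k, v, "attnf32_th")
print("attn values sanity check:",
      torch.allclose(out_fa1, out_ref, rtol=0, atol=1e-2))
```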
