
Commit da0b939

[FlashAttention] Refactor FlashAttention PyTorch bindings (#55)
* Update flash_attn.py
* Update README.md
* Update flash_attn.py
* Rename flash_attn_1_fwd_f32.cu to flash_attn_1.cu
* Rename flash_attn_2_fwd_f16_mma_m16n8k16.cu to flash_attn_2_mma.cu
* Rename flash_attn_2_fwd_f32.cu to flash_attn_2.cu
* Update flash_attn_1.cu
* Update flash_attn_1.cu
* Update flash_attn_2_mma.cu
* Update flash_attn.cc
* Update flash_attn.py
* Update README.md
* Update flash_attn.py
* Rename flash_attn_1.cu to flash_attn.cu
* Rename flash_attn_2_mma.cu to flash_attn_mma.cu
* Delete flash-attn/flash_attn_2.cu
* Update flash_attn.py
* Update README.md
* Update README.md
* Update README.md
1 parent 7cf1879 commit da0b939

File tree

7 files changed: +234 / -369 lines

README.md

Lines changed: 3 additions & 5 deletions
```diff
@@ -99,7 +99,7 @@
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_..._bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_..._dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
 | ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_t_8x8_sliced_k_f16x4_pack](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
@@ -111,10 +111,8 @@
 | ✔️ [hgemv_k32_f16](./hgemv/hgemv.cu)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
 | ✔️ [hgemv_k128_f16x4](./hgemv/hgemv.cu)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
 | ✔️ [hgemv_k16_f16](./hgemv/hgemv.cu)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
-| ✔️ [flash_attn_1_fwd_f32](./flash-attn/flash_attn_1_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
-|[flash_attn_2_fwd_f32](./flash-attn/flash_attn_2_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
-|[flash_attn_2_fwd_f16](./flash-attn/flash_attn_2_fwd_f32.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
-| ✔️ [flash_attn_2_fwd_f16_m16n8k16*](./flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_1_fwd_f32](./flash-attn/flash_attn.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
+| ✔️ [flash_attn_2_fwd_f16_m16n8k16*](./flash-attn/flash_attn_mma.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
 | ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️|
 | ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️|
```

flash-attn/README.md

Lines changed: 173 additions & 12 deletions
````diff
@@ -5,11 +5,11 @@
 It contains the following:
 
 - [X] flash_attn_1_fwd_f32_kernel
-- [ ] flash_attn_2_fwd_f32_kernel
-- [ ] flash_attn_2_fwd_f16_kernel
-- [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel
+- [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel (ldmatrix + MMA)
 - [X] PyTorch bindings
 
+The FlashAttention code in this repo is intended only for learning CUDA programming; for the best performance, please use the official implementation: [flash-attention](https://github.com/Dao-AILab/flash-attention)
+
 ### Running the tests
 ```bash
 # Test only the Ada architecture; if unspecified, all architectures are compiled by default, which takes a long time: Volta, Ampere, Ada, Hopper, ...
@@ -18,13 +18,174 @@ python3 flash_attn.py
 ```
 The log looks like this:
 ```bash
---------------------------------------------------------------------------------
-out_fa1f32: [-0.07522935, -0.06757538, -0.30396557], time:0.77749610ms
-out_fa1f32(v2): [-0.07522935, -0.06757538, -0.30396557], time:0.77480674ms
-out_attnf32_th: [-0.07522935, -0.06757541, -0.30396566], time:0.05429983ms
---------------------------------------------------------------------------------
-out_fa2mmaf16: [-0.07525635, -0.06762695, -0.30395508], time:0.01422763ms
-out_fa2mmaf16(v2): [-0.07525635, -0.06762695, -0.30395508], time:0.01072645ms
-out_attnf16_th: [-0.07525635, -0.06756592, -0.30395508], time:0.05322218ms
---------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B: batch_size, H: n_head, N: seq_len, D: head_dim
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=256, D=64
+out_FA1f32: ['0.01037013 ', '-0.09995531 ', '0.09193697 '], time:9.288564ms
+out_f32_th(naive): ['0.01037012 ', '-0.09995528 ', '0.09193695 '], time:0.086453ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.01031494 ', '-0.09997559 ', '0.09197998 '], time:0.047593ms
+out_f16_th(naive): ['0.01040649 ', '-0.10003662 ', '0.09197998 '], time:0.053408ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=256, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.15332031 ', '0.15917969 ', '0.07592773 '], time:0.091217ms
+out_f16_th(naive): ['0.15368652 ', '0.15905762 ', '0.07580566 '], time:0.052757ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=512, D=64
+out_FA1f32: ['0.01696955 ', '-0.05399467 ', '-0.03177956 '], time:37.062004ms
+out_f32_th(naive): ['0.01696953 ', '-0.05399465 ', '-0.03177955 '], time:0.471001ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.01699829 ', '-0.0539856 ', '-0.0317688 '], time:0.168507ms
+out_f16_th(naive): ['0.01699829 ', '-0.0539856 ', '-0.03173828 '], time:0.132778ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=512, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.06872559 ', '-0.07714844 ', '0.04348755 '], time:0.326455ms
+out_f16_th(naive): ['0.06872559 ', '-0.07720947 ', '0.04345703 '], time:0.152197ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=1024, D=64
+out_FA1f32: ['-0.04256601 ', '0.0555016 ', '0.05054659 '], time:148.082373ms
+out_f32_th(naive): ['-0.04256602 ', '0.05550159 ', '0.05054657 '], time:2.673364ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.0425415 ', '0.05551147 ', '0.05053711 '], time:0.633800ms
+out_f16_th(naive): ['-0.0425415 ', '0.05545044 ', '0.05053711 '], time:1.276960ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=8, N=1024, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.00053024 ', '0.04940796 ', '-0.01649475 '], time:1.235073ms
+out_f16_th(naive): ['-0.00051165 ', '0.04946899 ', '-0.01644897 '], time:1.371036ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=256, D=64
+out_FA1f32: ['0.06706338 ', '-0.01847678 ', '-0.02532079 '], time:9.592953ms
+out_f32_th(naive): ['0.0670634 ', '-0.01847675 ', '-0.02532081 '], time:0.150659ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.06719971 ', '-0.01847839 ', '-0.02529907 '], time:0.060866ms
+out_f16_th(naive): ['0.06713867 ', '-0.01846313 ', '-0.0252533 '], time:0.063777ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=256, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.05142212 ', '0.03041077 ', '-0.08868408 '], time:0.132723ms
+out_f16_th(naive): ['-0.05151367 ', '0.03018188 ', '-0.08911133 '], time:0.079043ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=512, D=64
+out_FA1f32: ['-0.03446965 ', '0.05762016 ', '0.07836776 '], time:38.253429ms
+out_f32_th(naive): ['-0.03446964 ', '0.05762014 ', '0.07836778 '], time:1.357274ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.03445435 ', '0.05758667 ', '0.07836914 '], time:0.218937ms
+out_f16_th(naive): ['-0.03445435 ', '0.05758667 ', '0.07830811 '], time:0.500908ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=512, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.00230026 ', '-0.05194092 ', '0.0164032 '], time:0.493281ms
+out_f16_th(naive): ['-0.00205803 ', '-0.05209351 ', '0.01664734 '], time:0.568807ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=1024, D=64
+out_FA1f32: ['0.02074369 ', '-0.01090947 ', '-0.01393144 '], time:152.446897ms
+out_f32_th(naive): ['0.02074368 ', '-0.01090949 ', '-0.01393143 '], time:5.296123ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.02073669 ', '-0.01097107 ', '-0.01395416 '], time:0.834603ms
+out_f16_th(naive): ['0.02073669 ', '-0.01092529 ', '-0.01390839 '], time:2.576745ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=8, H=16, N=1024, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.08306885 ', '0.03659058 ', '0.04852295 '], time:1.907628ms
+out_f16_th(naive): ['0.08319092 ', '0.03668213 ', '0.04858398 '], time:2.696407ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=256, D=64
+out_FA1f32: ['0.09634054 ', '-0.02606717 ', '0.13369624 '], time:9.618666ms
+out_f32_th(naive): ['0.09634058 ', '-0.02606717 ', '0.13369617 '], time:0.147052ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.09649658 ', '-0.02606201 ', '0.13366699 '], time:0.060964ms
+out_f16_th(naive): ['0.09631348 ', '-0.02613831 ', '0.13366699 '], time:0.063334ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=256, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.0680542 ', '0.18212891 ', '0.09741211 '], time:0.132513ms
+out_f16_th(naive): ['-0.0680542 ', '0.18212891 ', '0.09747314 '], time:0.079212ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=512, D=64
+out_FA1f32: ['0.06110233 ', '-0.03080001 ', '0.06487844 '], time:38.171313ms
+out_f32_th(naive): ['0.06110234 ', '-0.0308 ', '0.06487839 '], time:1.358862ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.06112671 ', '-0.03077698 ', '0.06488037 '], time:0.218849ms
+out_f16_th(naive): ['0.06109619 ', '-0.03079224 ', '0.06488037 '], time:0.497117ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=512, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.00991058 ', '-0.18884277 ', '-0.04980469 '], time:0.493472ms
+out_f16_th(naive): ['-0.0098877 ', '-0.18884277 ', '-0.04980469 '], time:0.573759ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=1024, D=64
+out_FA1f32: ['-0.01831236 ', '-0.07696866 ', '-0.04614653 '], time:152.500360ms
+out_f32_th(naive): ['-0.01831233 ', '-0.07696865 ', '-0.04614652 '], time:5.295737ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.01831055 ', '-0.07696533 ', '-0.04614258 '], time:0.834262ms
+out_f16_th(naive): ['-0.01826477 ', '-0.0769043 ', '-0.04614258 '], time:2.576706ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=8, N=1024, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.04501343 ', '0.07751465 ', '-0.01131439 '], time:1.907537ms
+out_f16_th(naive): ['0.04501343 ', '0.07745361 ', '-0.01132965 '], time:2.697947ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=256, D=64
+out_FA1f32: ['0.05493443 ', '0.03093347 ', '-0.05244123 '], time:12.086096ms
+out_f32_th(naive): ['0.05493441 ', '0.03093351 ', '-0.05244119 '], time:0.518868ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.05496216 ', '0.03089905 ', '-0.05227661 '], time:0.083928ms
+out_f16_th(naive): ['0.05487061 ', '0.03102112 ', '-0.05239868 '], time:0.133991ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=256, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.03808594 ', '-0.19189453 ', '0.00264549 '], time:0.192747ms
+out_f16_th(naive): ['-0.03778076 ', '-0.19189453 ', '0.00281334 '], time:0.178058ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=512, D=64
+out_FA1f32: ['0.02739076 ', '0.01203587 ', '0.09457675 '], time:48.142586ms
+out_f32_th(naive): ['0.02739077 ', '0.01203588 ', '0.09457672 '], time:2.749476ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['0.02740479 ', '0.01203918 ', '0.09454346 '], time:0.291946ms
+out_f16_th(naive): ['0.02740479 ', '0.01203156 ', '0.09460449 '], time:1.350477ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=512, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.06494141 ', '-0.06427002 ', '-0.04528809 '], time:0.690589ms
+out_f16_th(naive): ['-0.06500244 ', '-0.06427002 ', '-0.04519653 '], time:1.470513ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=1024, D=64
+out_FA1f32: ['-0.02254915 ', '0.00821745 ', '0.09361463 '], time:196.162612ms
+out_f32_th(naive): ['-0.02254917 ', '0.00821746 ', '0.09361461 '], time:10.451190ms
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.02252197 ', '0.00821686 ', '0.09368896 '], time:1.106799ms
+out_f16_th(naive): ['-0.02255249 ', '0.00818634 ', '0.09368896 '], time:5.125363ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+B=16, H=16, N=1024, D=128
+----------------------------------------------------------------------------------------------------
+out_FA2MMAf16: ['-0.07330322 ', '-0.06152344 ', '0.00090456 '], time:3.174434ms
+out_f16_th(naive): ['-0.07336426 ', '-0.06149292 ', '0.00105381 '], time:5.335908ms
+----------------------------------------------------------------------------------------------------
 ```
````
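The FA1/FA2 kernels benchmarked above avoid materializing the full N×N score matrix by using an online (streaming) softmax: a running max, a running denominator, and a rescaled output accumulator are updated tile by tile. A minimal single-row, single-output-dimension Python sketch of that update rule (function names are illustrative, not from this repo):

```python
import math

def softmax(xs):
    # standard two-pass softmax: find the max, then normalize
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def online_softmax_weighted_sum(scores, values):
    # one streaming pass: running max m, running denominator d, and a
    # rescaled accumulator acc -- the per-tile update FlashAttention applies,
    # shown here one score/value pair at a time
    m, d, acc = float("-inf"), 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first step
        p = math.exp(s - m_new)
        d = d * scale + p            # rescale old denominator, add new term
        acc = acc * scale + p * v    # rescale old accumulator likewise
        m = m_new
    return acc / d

scores = [1.0, 3.0, 2.0, 5.0]
values = [0.1, 0.2, 0.3, 0.4]
two_pass = sum(p * v for p, v in zip(softmax(scores), values))
one_pass = online_softmax_weighted_sum(scores, values)
print(two_pass, one_pass)  # the two agree to floating-point accuracy
```

Rescaling by `exp(m_old - m_new)` whenever the running max grows is what keeps every exponent non-positive, so the streaming version is as numerically stable as the two-pass softmax.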

flash-attn/flash_attn.cc

Lines changed: 2 additions & 6 deletions
```diff
@@ -5,14 +5,10 @@
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
 
-torch::Tensor flash_attn_1_fwd_f32(torch::Tensor Q, torch::Tensor K, torch::Tensor V);
-void flash_attn_1_fwd_f32_v2(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O);
-torch::Tensor flash_attn_2_fwd_f16_mma_m16n8k16(torch::Tensor Q, torch::Tensor K, torch::Tensor V);
-void flash_attn_2_fwd_f16_mma_m16n8k16_v2(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O);
+void flash_attn_1_fwd_f32(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O);
+void flash_attn_2_fwd_f16_mma_m16n8k16(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(flash_attn_1_fwd_f32)
-  TORCH_BINDING_COMMON_EXTENSION(flash_attn_1_fwd_f32_v2)
   TORCH_BINDING_COMMON_EXTENSION(flash_attn_2_fwd_f16_mma_m16n8k16)
-  TORCH_BINDING_COMMON_EXTENSION(flash_attn_2_fwd_f16_mma_m16n8k16_v2)
 }
```
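In the benchmark logs, the f16 MMA kernel matches the f32 reference only to roughly 3-4 decimal places (e.g. 0.09197998 vs 0.09193697). That is expected: binary16 has a 10-bit mantissa, so near 0.09 the spacing between representable values is 2^-14 ≈ 6.1e-5. A quick standalone check using Python's `struct` half-float format (an illustrative snippet, not part of this repo):

```python
import struct

def round_to_f16(x: float) -> float:
    # round-trip through IEEE 754 binary16 ('e' format) to observe
    # the value an f16 kernel would actually store
    return struct.unpack("e", struct.pack("e", x))[0]

f32_val = 0.09193697             # from the out_f32_th log line above
f16_val = round_to_f16(f32_val)
# near 0.09 the f16 ulp is 2**-14 ~= 6.1e-5, so only ~4 decimal digits survive
print(f16_val, abs(f16_val - f32_val))
# the f16 kernel's logged 0.09197998 is itself an exact binary16 value (1507/16384)
print(round_to_f16(0.09197998046875) == 0.09197998046875)
```

So differences of a few 1e-5 between `out_FA2MMAf16` and the f32 naive reference are rounding, not kernel bugs.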
