  Contents:

  - [X] flash_attn_1_fwd_f32_kernel
- - [ ] flash_attn_2_fwd_f32_kernel
- - [ ] flash_attn_2_fwd_f16_kernel
- - [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel
+ - [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel (ldmatrix + MMA)
  - [X] PyTorch bindings

+ The FlashAttention kernels in this repository are intended only for learning CUDA programming. For best performance, use the official FlashAttention implementation: [flash-attention](https://github.com/Dao-AILab/flash-attention)
+
  ### Running the tests
  ```bash
  # Test the Ada architecture only. If no architecture is specified, all architectures are compiled by default, which takes longer: Volta, Ampere, Ada, Hopper, ...
@@ -18,13 +18,174 @@ python3 flash_attn.py
  ```
  The log output is as follows:
  ```bash
- --------------------------------------------------------------------------------
- out_fa1f32: [-0.07522935, -0.06757538, -0.30396557], time:0.77749610ms
- out_fa1f32(v2): [-0.07522935, -0.06757538, -0.30396557], time:0.77480674ms
- out_attnf32_th: [-0.07522935, -0.06757541, -0.30396566], time:0.05429983ms
- --------------------------------------------------------------------------------
- out_fa2mmaf16: [-0.07525635, -0.06762695, -0.30395508], time:0.01422763ms
- out_fa2mmaf16(v2): [-0.07525635, -0.06762695, -0.30395508], time:0.01072645ms
- out_attnf16_th: [-0.07525635, -0.06756592, -0.30395508], time:0.05322218ms
- --------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B: batch_size, H: n_head, N: seq_len, D: head_dim
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=256, D=64
+ out_FA1f32: ['0.01037013', '-0.09995531', '0.09193697'], time:9.288564ms
+ out_f32_th(naive): ['0.01037012', '-0.09995528', '0.09193695'], time:0.086453ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.01031494', '-0.09997559', '0.09197998'], time:0.047593ms
+ out_f16_th(naive): ['0.01040649', '-0.10003662', '0.09197998'], time:0.053408ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=256, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.15332031', '0.15917969', '0.07592773'], time:0.091217ms
+ out_f16_th(naive): ['0.15368652', '0.15905762', '0.07580566'], time:0.052757ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=512, D=64
+ out_FA1f32: ['0.01696955', '-0.05399467', '-0.03177956'], time:37.062004ms
+ out_f32_th(naive): ['0.01696953', '-0.05399465', '-0.03177955'], time:0.471001ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.01699829', '-0.0539856', '-0.0317688'], time:0.168507ms
+ out_f16_th(naive): ['0.01699829', '-0.0539856', '-0.03173828'], time:0.132778ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=512, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.06872559', '-0.07714844', '0.04348755'], time:0.326455ms
+ out_f16_th(naive): ['0.06872559', '-0.07720947', '0.04345703'], time:0.152197ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=1024, D=64
+ out_FA1f32: ['-0.04256601', '0.0555016', '0.05054659'], time:148.082373ms
+ out_f32_th(naive): ['-0.04256602', '0.05550159', '0.05054657'], time:2.673364ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.0425415', '0.05551147', '0.05053711'], time:0.633800ms
+ out_f16_th(naive): ['-0.0425415', '0.05545044', '0.05053711'], time:1.276960ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=8, N=1024, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.00053024', '0.04940796', '-0.01649475'], time:1.235073ms
+ out_f16_th(naive): ['-0.00051165', '0.04946899', '-0.01644897'], time:1.371036ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=256, D=64
+ out_FA1f32: ['0.06706338', '-0.01847678', '-0.02532079'], time:9.592953ms
+ out_f32_th(naive): ['0.0670634', '-0.01847675', '-0.02532081'], time:0.150659ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.06719971', '-0.01847839', '-0.02529907'], time:0.060866ms
+ out_f16_th(naive): ['0.06713867', '-0.01846313', '-0.0252533'], time:0.063777ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=256, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.05142212', '0.03041077', '-0.08868408'], time:0.132723ms
+ out_f16_th(naive): ['-0.05151367', '0.03018188', '-0.08911133'], time:0.079043ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=512, D=64
+ out_FA1f32: ['-0.03446965', '0.05762016', '0.07836776'], time:38.253429ms
+ out_f32_th(naive): ['-0.03446964', '0.05762014', '0.07836778'], time:1.357274ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.03445435', '0.05758667', '0.07836914'], time:0.218937ms
+ out_f16_th(naive): ['-0.03445435', '0.05758667', '0.07830811'], time:0.500908ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=512, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.00230026', '-0.05194092', '0.0164032'], time:0.493281ms
+ out_f16_th(naive): ['-0.00205803', '-0.05209351', '0.01664734'], time:0.568807ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=1024, D=64
+ out_FA1f32: ['0.02074369', '-0.01090947', '-0.01393144'], time:152.446897ms
+ out_f32_th(naive): ['0.02074368', '-0.01090949', '-0.01393143'], time:5.296123ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.02073669', '-0.01097107', '-0.01395416'], time:0.834603ms
+ out_f16_th(naive): ['0.02073669', '-0.01092529', '-0.01390839'], time:2.576745ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=8, H=16, N=1024, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.08306885', '0.03659058', '0.04852295'], time:1.907628ms
+ out_f16_th(naive): ['0.08319092', '0.03668213', '0.04858398'], time:2.696407ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=256, D=64
+ out_FA1f32: ['0.09634054', '-0.02606717', '0.13369624'], time:9.618666ms
+ out_f32_th(naive): ['0.09634058', '-0.02606717', '0.13369617'], time:0.147052ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.09649658', '-0.02606201', '0.13366699'], time:0.060964ms
+ out_f16_th(naive): ['0.09631348', '-0.02613831', '0.13366699'], time:0.063334ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=256, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.0680542', '0.18212891', '0.09741211'], time:0.132513ms
+ out_f16_th(naive): ['-0.0680542', '0.18212891', '0.09747314'], time:0.079212ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=512, D=64
+ out_FA1f32: ['0.06110233', '-0.03080001', '0.06487844'], time:38.171313ms
+ out_f32_th(naive): ['0.06110234', '-0.0308', '0.06487839'], time:1.358862ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.06112671', '-0.03077698', '0.06488037'], time:0.218849ms
+ out_f16_th(naive): ['0.06109619', '-0.03079224', '0.06488037'], time:0.497117ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=512, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.00991058', '-0.18884277', '-0.04980469'], time:0.493472ms
+ out_f16_th(naive): ['-0.0098877', '-0.18884277', '-0.04980469'], time:0.573759ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=1024, D=64
+ out_FA1f32: ['-0.01831236', '-0.07696866', '-0.04614653'], time:152.500360ms
+ out_f32_th(naive): ['-0.01831233', '-0.07696865', '-0.04614652'], time:5.295737ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.01831055', '-0.07696533', '-0.04614258'], time:0.834262ms
+ out_f16_th(naive): ['-0.01826477', '-0.0769043', '-0.04614258'], time:2.576706ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=8, N=1024, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.04501343', '0.07751465', '-0.01131439'], time:1.907537ms
+ out_f16_th(naive): ['0.04501343', '0.07745361', '-0.01132965'], time:2.697947ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=256, D=64
+ out_FA1f32: ['0.05493443', '0.03093347', '-0.05244123'], time:12.086096ms
+ out_f32_th(naive): ['0.05493441', '0.03093351', '-0.05244119'], time:0.518868ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.05496216', '0.03089905', '-0.05227661'], time:0.083928ms
+ out_f16_th(naive): ['0.05487061', '0.03102112', '-0.05239868'], time:0.133991ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=256, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.03808594', '-0.19189453', '0.00264549'], time:0.192747ms
+ out_f16_th(naive): ['-0.03778076', '-0.19189453', '0.00281334'], time:0.178058ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=512, D=64
+ out_FA1f32: ['0.02739076', '0.01203587', '0.09457675'], time:48.142586ms
+ out_f32_th(naive): ['0.02739077', '0.01203588', '0.09457672'], time:2.749476ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['0.02740479', '0.01203918', '0.09454346'], time:0.291946ms
+ out_f16_th(naive): ['0.02740479', '0.01203156', '0.09460449'], time:1.350477ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=512, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.06494141', '-0.06427002', '-0.04528809'], time:0.690589ms
+ out_f16_th(naive): ['-0.06500244', '-0.06427002', '-0.04519653'], time:1.470513ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=1024, D=64
+ out_FA1f32: ['-0.02254915', '0.00821745', '0.09361463'], time:196.162612ms
+ out_f32_th(naive): ['-0.02254917', '0.00821746', '0.09361461'], time:10.451190ms
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.02252197', '0.00821686', '0.09368896'], time:1.106799ms
+ out_f16_th(naive): ['-0.02255249', '0.00818634', '0.09368896'], time:5.125363ms
+ ----------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------
+ B=16, H=16, N=1024, D=128
+ ----------------------------------------------------------------------------------------------------
+ out_FA2MMAf16: ['-0.07330322', '-0.06152344', '0.00090456'], time:3.174434ms
+ out_f16_th(naive): ['-0.07336426', '-0.06149292', '0.00105381'], time:5.335908ms
+ ----------------------------------------------------------------------------------------------------
  ```
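The log pairs each kernel timing with a naive PyTorch baseline for the same shape, so the relative speed per shape can be read off by dividing the two `time:` values. The following is a minimal sketch of that arithmetic, assuming the log format shown above. The names `parse_log`, `speedup`, and `SAMPLE` are hypothetical helpers introduced for illustration and are not part of this repository. The sample text reuses two shapes from the log.

```python
import re

# Two shapes copied from the log above; diff markers are omitted for brevity.
SAMPLE = """\
B=8, H=8, N=256, D=64
out_FA2MMAf16: ['0.01031494', '-0.09997559', '0.09197998'], time:0.047593ms
out_f16_th(naive): ['0.01040649', '-0.10003662', '0.09197998'], time:0.053408ms
B=16, H=16, N=1024, D=128
out_FA2MMAf16: ['-0.07330322', '-0.06152344', '0.00090456'], time:3.174434ms
out_f16_th(naive): ['-0.07336426', '-0.06149292', '0.00105381'], time:5.335908ms
"""

# A shape line looks like "B=8, H=8, N=256, D=64".
SHAPE_RE = re.compile(r"B=(\d+), H=(\d+), N=(\d+), D=(\d+)")
# A result line looks like "out_<tag>: [...], time:<float>ms".
RESULT_RE = re.compile(r"(out_[^:\s]+):.*time:([0-9.]+)ms")


def parse_log(text):
    """Return {(B, H, N, D): {tag: time_ms}} for every shape in the log."""
    results = {}
    shape = None
    for line in text.splitlines():
        shape_match = SHAPE_RE.search(line)
        if shape_match:
            shape = tuple(int(g) for g in shape_match.groups())
            results.setdefault(shape, {})
            continue
        result_match = RESULT_RE.search(line)
        if result_match and shape is not None:
            results[shape][result_match.group(1)] = float(result_match.group(2))
    return results


def speedup(times, kernel="out_FA2MMAf16", baseline="out_f16_th(naive)"):
    """Baseline time divided by kernel time; above 1.0 means the kernel is faster."""
    return times[baseline] / times[kernel]


if __name__ == "__main__":
    for shape, times in parse_log(SAMPLE).items():
        print(shape, f"speedup vs naive: {speedup(times):.2f}x")
```

A ratio above 1.0 means the MMA kernel was faster than the naive baseline for that shape, and a ratio below 1.0 means it was slower. Because the regular expressions use `search`, the leading `+` diff markers do not need to be stripped before parsing.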