@@ -36,8 +36,8 @@ def get_args():
     parser.add_argument("--no-rand-k", '--no-rk', action="store_true")
     parser.add_argument("--no-rand-v", '--no-rv', action="store_true")
     parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true")
-    parser.add_argument("--naive", action="store_true")
-    parser.add_argument("--sdpa", action="store_true")
+    parser.add_argument("--run-torch-unfused", '--torch', action="store_true")
+    parser.add_argument("--run-torch-sdpa", '--sdpa', action="store_true")
     parser.add_argument("--check", action="store_true")
     parser.add_argument("--show-all", '--show', action="store_true")
     parser.add_argument("--B", type=int, default=None)
@@ -46,6 +46,7 @@ def get_args():
     parser.add_argument("--D", type=int, default=None)
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--verbose", '--v', action="store_true")
     parser.add_argument("--warmup", type=int, default=2)
     parser.add_argument("--iters", type=int, default=10)
     parser.add_argument("--range-k", '--gk', action="store_true")
@@ -59,10 +60,10 @@ def get_args():
 # Load the CUDA kernel as a python module
 lib = load(name='flash_attn_lib',
            sources=[
-               './naive/flash_attn_cuda.cu',
-               './mma/flash_attn_mma_naive.cu',
-               './mma/flash_attn_mma_stage.cu',
-               './pybind/flash_attn.cc'],
+               './mma/flash_attn_mma_split_kv.cu',
+               './mma/flash_attn_mma_split_q.cu',
+               './pybind/flash_attn.cc'
+           ],
            extra_cuda_cflags=[
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
@@ -72,10 +73,43 @@ def get_args():
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
+               "-Xptxas -v",
+               "-diag-suppress 177",
                f"-I {project_dir}/kernels/flash-attn/utils",
                "-DFLASH_ATTN_MMA_DEBUG" if args.debug else ""
            ],
-           extra_cflags=['-std=c++17'])
+           extra_cflags=['-std=c++17'],
+           verbose=args.verbose)
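+# NOTE: the compiled extension is expected to expose flash_attn_mma_stages_split_kv
+# and flash_attn_mma_stages_split_q (bound in pybind/flash_attn.cc); both are
+# benchmarked below.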
+
+
+def get_mha_tflops(B, H, N, D, T=1.0):
+    # Q @ K^T FLOPs
+    flops_qk = B * H * N * N * (2 * D - 1)
+
+    # Scaling FLOPs
+    flops_scaling = B * H * N * N
+
+    # Safe_Softmax FLOPs
+    flops_row_max = B * H * N * (N - 1)  # row max
+    flops_subtract_max = B * H * N * N   # sub max
+    flops_exp = B * H * N * N            # pointwise exp
+    flops_row_sum = B * H * N * (N - 1)  # row sum
+    flops_normalization = B * H * N * N  # normalization
+
+    flops_safe_softmax = flops_row_max + flops_subtract_max + flops_exp + flops_row_sum + flops_normalization
+
+    # P @ V FLOPs
+    flops_pv = B * H * N * D * (2 * N - 1)
+
+    # Total FLOPs
+    total_flops = flops_qk + flops_scaling + flops_safe_softmax + flops_pv
+
+    # Convert to TFLOPS
+    # 1 TFLOPS = 10^12 FLOPS
+    # ref: https://imgtec.eetrend.com/blog/2021/100062210.html
+    tflops = total_flops * 1e-12 / T
+
+    return tflops
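+
+# e.g. get_mha_tflops(1, 8, 1024, 64, T=1e-3) gives roughly 2.19 TFLOPS; the two
+# matmuls (Q @ K^T and P @ V) dominate, contributing about 2.14 of the ~2.19
+# GFLOPs per forward pass.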
 
 
 def run_benchmark(perf_func: callable,
@@ -123,8 +157,14 @@ def run_benchmark(perf_func: callable,
         out = perf_func(q, k, v)
     torch.cuda.synchronize()
     end = time.time()
+    total_secs = (end - start)
     total_time = (end - start) * 1000  # ms
     mean_time = total_time / iters
+    mean_secs = total_secs / iters
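+    # NOTE: flash_attn_func consumes (B, N, H, D) tensors while the other paths
+    # here use (B, H, N, D); assume a "flash" substring in the tag marks the
+    # flash-attn call so the TFLOPS estimate reads H and N from the right dims.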
+    B, H, N, D = q.size()
+    if "flash" in tag:
+        B, N, H, D = q.size()
+    TFLOPS = get_mha_tflops(B, H, N, D, mean_secs)
     out_info = f"{tag}"
     out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
     out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
@@ -133,10 +173,11 @@ def run_benchmark(perf_func: callable,
     out_val = out_val_first[:2]
     out_val.append(out_val_last[-1])
     out_val = [f"{v:<12}" for v in out_val]
-    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
+    print(f"{out_info:>25}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS: {TFLOPS:<6.2f}")
     if show_all:
         print(out)
     time.sleep(0.05)
+    torch.cuda.synchronize()
     return out.clone(), mean_time
 
 
@@ -159,18 +200,38 @@ def get_qkvo(B, H, N, D):
     v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
 
     o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
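+    # Extra layouts for the different kernels: tk is K^T with shape (B, H, D, N)
+    # for the MMA kernels, and fq/fk/fv are (B, N, H, D) copies for flash_attn_func.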
+    tk = k.transpose(-2, -1).contiguous()
+    fq = q.transpose(1, 2).contiguous()
+    fk = k.transpose(1, 2).contiguous()
+    fv = v.transpose(1, 2).contiguous()
 
-    return q, k, v, o
+    return q, k, v, o, tk, fq, fk, fv
 
 
 # un-fused naive attn
-def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
     att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
     att = F.softmax(att, dim=-1)
     y = att @ v
     return y
 
 
+def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
+                    tag: str = "out_mma", show_all: bool = False):
+    out_flash = out_flash.transpose(1, 2)
+    if show_all:
+        for i in range(int(N / 8)):
+            if i < 4:
+                print("-" * 120)
+                print(f"out_flash[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
+                print(out_flash[:, :, (i * 8):(i + 1) * 8, :].float())
+                print(f"{tag}[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
+                print(out_mma[:, :, (i * 8):(i + 1) * 8, :].float())
+                print("-" * 120)
+    all_close = torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2)
+    print(f"out_flash vs {tag}: {all_close}")
+
+
 Bs = [1, 2, 4] if not args.B else [args.B]
 Hs = [1, 4, 8] if not args.H else [args.H]
 Ns = [1024, 2048] if not args.N else [args.N]
@@ -180,42 +241,28 @@ def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 
 seed = args.seed if args.seed else random.choice(range(10000))
 set_rand_seed(seed)
-print("-" * 100)
-print(" " * 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
+print("-" * 120)
+print(" " * 20 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
       f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}")
 
 for (B, H, N, D) in BHNDs:
-    print("-" * 100)
-    print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
-    q, k, v, o = get_qkvo(B, H, N, D)
-    tk = k.transpose(-2, -1).contiguous()
-    fq = q.transpose(1, 2).contiguous()
-    fk = k.transpose(1, 2).contiguous()
-    fv = v.transpose(1, 2).contiguous()
+    print("-" * 120)
+    print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
+    q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
     torch.cuda.synchronize()
 
-    if args.naive:
-        out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)")
-
-    # using fp16 Tesor Core MMA instruction
-    out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o)
-    out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1)
-    out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2)
-    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
-
-    if args.sdpa:
-        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
-    print("-" * 100)
+    if args.run_torch_unfused:
+        out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
+    out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
+    out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
+    out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
+    out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
+    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
+    if args.run_torch_sdpa:
+        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
+    print("-" * 120)
 
     torch.cuda.synchronize()
     if args.check:
-        out_flash = out_flash.transpose(1, 2)
-        for i in range(int(N / 8)):
-            if i < 4:
-                print("-" * 100)
-                print(f"out_flash[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
-                print(out_flash[:, :, (i * 8):(i + 1) * 8, :].float())
-                print(f"out_mma_stage1[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
-                print(out_mma_stage1[:, :, (i * 8):(i + 1) * 8, :].float())
-                print("-" * 100)
-        print(f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}")
+        check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all)
+        check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all)