1
1
import torch
2
2
import torch .nn .functional as F
3
3
import tilelang
4
- from tilelang .autotuner import *
5
4
import tilelang .language as T
6
5
from einops import rearrange , einsum
7
6
import argparse
8
7
9
8
tilelang .disable_cache ()
10
9
11
10
11
def get_configs():
    """Build the autotuning search space: one dict per combination of tile sizes.

    Returns a list of config dicts with keys ``block_N``, ``block_H``,
    ``num_split`` and ``threads``, covering the full Cartesian product of the
    candidate values below.
    """
    import itertools

    # Candidate values for each tunable kernel parameter, keyed by the
    # keyword-argument name the kernel expects.
    search_space = {
        "block_N": [16, 32, 64, 128],
        "block_H": [16, 32, 64, 128],
        "num_split": [1, 2, 4, 8, 16, 32],
        "threads": [128, 256],
    }

    names = list(search_space)
    return [
        dict(zip(names, combo))
        for combo in itertools.product(*search_space.values())
    ]
26
+
27
+
28
+ @tilelang .autotune (configs = get_configs ())
12
29
@tilelang .jit (
13
30
out_idx = [6 ], pass_configs = {
14
31
tilelang .PassConfigKey .TL_ENABLE_FAST_MATH : True ,
@@ -273,26 +290,39 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
273
290
274
291
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    # Fix: the option string must not carry a trailing space ('--autotune ');
    # argparse derives the dest from the option string, and a trailing space
    # would make the flag unreachable via `args.autotune` below.
    parser.add_argument('--autotune', action='store_true', help='auto tune')
    args = parser.parse_args()

    batch, heads, kv_heads, kv_ctx, dim, pe_dim = (
        args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim)
    enable_autotune = args.autotune

    # 2 FLOPs per multiply-accumulate: QK^T spans (dim + pe_dim) per element,
    # PV spans dim.
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops

    # Hand-picked tile configuration, used only when autotuning is disabled.
    BLOCK_N = 32
    BLOCK_H = 64
    num_split = 4
    threads = 128

    if enable_autotune:
        # Tile-size arguments are omitted on purpose: the @tilelang.autotune
        # decorator supplies block_N/block_H/num_split/threads from get_configs().
        kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
    else:
        kernel = flashmla_decode(
            batch,
            heads,
            kv_heads,
            kv_ctx,
            dim,
            pe_dim,
            BLOCK_N,
            BLOCK_H,
            num_split,
            threads=threads)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    input_tensors = profiler._get_inputs()
    tilelang_output = kernel(*input_tensors)
    # NOTE(review): a few lines are hidden in this view at the diff-hunk boundary
    # here (presumably the correctness check of tilelang_output against
    # ref_program) — restore them from the full file before relying on this edit.
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
306
-
307
- # Enable Auto Tuning
308
-
309
-
310
- def get_configs ():
311
- import itertools
312
- BLOCK_N = [16 , 32 , 64 , 128 ]
313
- BLOCK_H = [16 , 32 , 64 , 128 ]
314
- num_split = [1 , 2 , 4 , 8 , 16 , 32 ]
315
- thread_num = [128 , 256 ]
316
-
317
- _configs = list (itertools .product (BLOCK_N , BLOCK_H , num_split , thread_num ))
318
-
319
- return [{
320
- "block_N" : c [0 ],
321
- "block_H" : c [1 ],
322
- "num_split" : c [2 ],
323
- "thread_num" : c [3 ],
324
- } for c in _configs ]
325
-
326
- def wrapped_kernel (block_N = None , block_H = None , num_split = None , thread_num = None ):
327
- return flashmla_decode (batch , heads , kv_heads , kv_ctx , dim , pe_dim , block_N , block_H ,
328
- num_split , thread_num )
329
-
330
- if enable_autotune :
331
- autotuner = AutoTuner .from_kernel (kernel = wrapped_kernel , configs = get_configs ())
332
- tune_result = autotuner .run (warmup = 3 , rep = 20 )
333
- best_latency = tune_result .latency
334
- best_config = tune_result .config
335
- print (f"Best latency: { best_latency } ms" )
336
- print (f"Best TFlops: { total_flops / best_latency * 1e-9 } TFlops" )
337
- print (f"Best config: { best_config } " )
0 commit comments