@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument("--warmup", "--w", type=int, default=1)
     parser.add_argument("--iters", "--i", type=int, default=5)
     parser.add_argument("--range-k", '--gk', action="store_true")
+    parser.add_argument("--build-others", '--others', action="store_true")
     parser.add_argument("--tag-hints", '--tags', '--hints', type=str, default=None)
     return parser.parse_args()
@@ -84,6 +85,8 @@ def get_build_sources():
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu')
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu')
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu')
+    if args.build_others:
+        build_sources.append('./mma/others/flash_attn_mma_share_qkv_s2g_o.cu')
     build_sources.append('./pybind/flash_attn.cc')
     return build_sources
@@ -127,6 +130,7 @@ def get_build_cuda_cflags(build_pkg: bool = False):
     extra_cuda_cflags.append("--expt-extended-lambda")
     extra_cuda_cflags.append("--use_fast_math")
     extra_cuda_cflags.append("-DFLASH_ATTN_MMA_DEBUG" if args.debug else "")
+    extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_L20" if "L20" in device_name else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_4090" if "4090" in device_name else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_3080" if "3080" in device_name else "")
@@ -137,13 +141,21 @@ def get_build_cuda_cflags(build_pkg: bool = False):
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/basic')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/swizzle')
+    extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/others')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/cutlass')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/pybind')
     extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/include')
     extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/tools/util/include')
     return extra_cuda_cflags
 
 
+def get_build_cflags():
+    extra_cflags = []
+    extra_cflags.append("-std=c++17")
+    extra_cflags.append("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else "")
+    return extra_cflags
+
+
 def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
     res_len = width - len(m)
     left_len = int(res_len / 2)
@@ -164,10 +176,15 @@ def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
 lib = load(name='flash_attn_lib',
            sources=get_build_sources(),
            extra_cuda_cflags=get_build_cuda_cflags(),
-           extra_cflags=['-std=c++17'],
+           extra_cflags=get_build_cflags(),
            verbose=args.verbose)
 
 
+if not args.build_others:
+    setattr(lib, "flash_attn_mma_stages_split_q_shared_qkv_s2g_o",
+            lambda q, k, v, o, s: o)
+
+
 def get_mha_tflops(B: int, H: int, N: int, D: int, secs: float = 1.0,
                    only_matmul: bool = False):
     # Q @ K^T FLOPs
@@ -235,6 +252,10 @@ def run_benchmark(perf_func: callable,
                 hit_hints = True
         if not hit_hints:
             return None, None
+
+    if not args.build_others:
+        if "s2g-o" in tag:
+            return None, None
 
     if "sdpa" in tag and (not args.run_torch_sdpa):
         return None, None
@@ -434,6 +455,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
     "mma(split-q+share-qkv+swizzle-qk+stage2)": 128,
     "mma(split-q+share-qkv+swizzle-qkv+stage1)": 256,
     "mma(split-q+share-qkv+swizzle-qkv+stage2)": 128,
+    "mma(split-q+share-qkv+s2g-o+stage1)": 256,
+    "mma(split-q+share-qkv+s2g-o+stage2)": 128,
     # Split-Q + QK Fine-grained Tiling
     "mma(split-q+tiling-qk+stage1)": 1024,
     "mma(split-q+tiling-qk+stage2)": 1024,
@@ -482,6 +505,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
         # Split-Q + Fully Shared QKV SMEM + Swizzle
         out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
         out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+        out_mma_share_qkv_s2g1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_s2g_o, q, k, v, "mma(split-q+share-qkv+s2g-o+stage1)", o, stages=1)
+        out_mma_share_qkv_s2g2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_s2g_o, q, k, v, "mma(split-q+share-qkv+s2g-o+stage2)", o, stages=2)
         out_mma_share_qkv_sq1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_q, q, k, v, "mma(split-q+share-qkv+swizzle-q+stage1)", o, stages=1)
         out_mma_share_qkv_sq2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_q, q, k, v, "mma(split-q+share-qkv+swizzle-q+stage2)", o, stages=2)
         out_mma_share_qkv_sqk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_qk, q, k, v, "mma(split-q+share-qkv+swizzle-qk+stage1)", o, stages=1)
@@ -524,6 +549,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
         # Split-Q + Fully Shared QKV SMEM
         check_all_close(out_flash, out_mma_share_qkv1, "out_mma_share_qkv1", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all)
+        check_all_close(out_flash, out_mma_share_qkv_s2g1, "out_mma_share_qkv_s2g1", args.check_all)
+        check_all_close(out_flash, out_mma_share_qkv_s2g2, "out_mma_share_qkv_s2g2", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sq1, "out_mma_share_qkv_sq1", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sq2, "out_mma_share_qkv_sq2", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sqk1, "out_mma_share_qkv_sqk1", args.check_all)
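For context, a minimal sketch (illustration only, not part of the patch) of the fallback the change installs when `--build-others` is not passed: the loaded extension is patched with a no-op stub so later references to `flash_attn_mma_stages_split_q_shared_qkv_s2g_o` still resolve, while `run_benchmark` separately skips any "s2g-o" tag. `SimpleNamespace` stands in for the compiled extension module here.

```python
# Illustration only: the no-op fallback pattern used in the patch when the
# optional "others" kernels are not compiled in.
from types import SimpleNamespace

lib = SimpleNamespace()   # stand-in for the flash_attn_lib extension from load()
build_others = False      # stand-in for args.build_others

if not build_others:
    # The s2g-o kernel was not built, so expose a stub with the same
    # (q, k, v, o, stages) signature that simply returns the output buffer.
    setattr(lib, "flash_attn_mma_stages_split_q_shared_qkv_s2g_o",
            lambda q, k, v, o, s: o)

# Later calls resolve without an AttributeError; the benchmark itself is never
# timed because run_benchmark() returns (None, None) for "s2g-o" tags.
out = lib.flash_attn_mma_stages_split_q_shared_qkv_s2g_o(None, None, None, "o", 1)
print(out)  # -> "o"
```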