
Commit 4c183ee

[FA2] support shared-qkv + O s2g kernel✔️ (#191)
* Create flash_attn_mma_share_qkv_s2g_o.cu
* Delete kernels/flash-attn/mma/basic/flash_attn_mma_share_qkv_s2g_o.cu
* Add shared-qkv + O s2g kernel
1 parent 3e539ed commit 4c183ee
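
For context on the title: "O s2g" (shared-to-global) means the output tile O is first staged in the thread block's shared memory, reusing the fully shared QKV SMEM buffer, and then copied from shared to global memory with wide, coalesced stores, rather than being written from MMA accumulator registers straight to global memory. Below is a minimal, hypothetical sketch of that epilogue pattern, assuming a row-major Br×kHeadDim fp16 tile with kHeadDim divisible by 8; all names and sizes are illustrative and not taken from the actual flash_attn_mma_share_qkv_s2g_o.cu:

```cuda
#include <cuda_fp16.h>

// Hypothetical s2g epilogue: copy the Br x kHeadDim O tile from shared
// memory to global memory with 128-bit vectorized stores. Consecutive
// threads write consecutive 16-byte chunks, so the global stores coalesce.
template <int Br, int kHeadDim>
__device__ __forceinline__ void store_o_s2g(half* __restrict__ o_gmem,
                                            const half* smem_o,
                                            int tile_row0) {
  constexpr int kVec   = 8;              // 8 halves = 16 bytes per store
  constexpr int kElems = Br * kHeadDim;  // elements in one O tile
  // The caller must __syncthreads() after the accumulators are written
  // into smem_o, so every thread reads a fully populated tile.
  for (int i = threadIdx.x * kVec; i < kElems; i += blockDim.x * kVec) {
    const int row = i / kHeadDim;        // row-major: i == row * kHeadDim + col
    const int col = i % kHeadDim;
    *reinterpret_cast<float4*>(o_gmem + (tile_row0 + row) * kHeadDim + col) =
        *reinterpret_cast<const float4*>(smem_o + i);
  }
}
```

The trade-off this pattern explores: the extra SMEM round trip costs a sync and shared-memory bandwidth, but it lets the block reorganize scattered MMA accumulator fragments into contiguous rows before touching global memory, so the final stores are fully coalesced and vectorized.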

File tree

4 files changed: +931 -1 lines changed

kernels/flash-attn/flash_attn_mma.py

Lines changed: 28 additions & 1 deletion
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument("--warmup", "--w", type=int, default=1)
     parser.add_argument("--iters", "--i", type=int, default=5)
     parser.add_argument("--range-k", '--gk', action="store_true")
+    parser.add_argument("--build-others", '--others', action="store_true")
     parser.add_argument("--tag-hints", '--tags', '--hints', type=str, default=None)
     return parser.parse_args()

@@ -84,6 +85,8 @@ def get_build_sources():
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu')
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu')
     build_sources.append('./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu')
+    if args.build_others:
+        build_sources.append('./mma/others/flash_attn_mma_share_qkv_s2g_o.cu')
     build_sources.append('./pybind/flash_attn.cc')
     return build_sources

@@ -127,6 +130,7 @@ def get_build_cuda_cflags(build_pkg: bool = False):
     extra_cuda_cflags.append("--expt-extended-lambda")
     extra_cuda_cflags.append("--use_fast_math")
     extra_cuda_cflags.append("-DFLASH_ATTN_MMA_DEBUG" if args.debug else "")
+    extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_L20" if "L20" in device_name else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_4090" if "4090" in device_name else "")
     extra_cuda_cflags.append("-DBUILD_FLASH_ATTN_MMA_3080" if "3080" in device_name else "")
@@ -137,13 +141,21 @@ def get_build_cuda_cflags(build_pkg: bool = False):
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/basic')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/swizzle')
+    extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/mma/others')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/cutlass')
     extra_cuda_cflags.append(f'-I {project_dir}/kernels/flash-attn/pybind')
     extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/include')
     extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/tools/util/include')
     return extra_cuda_cflags


+def get_build_cflags():
+    extra_cflags = []
+    extra_cflags.append("-std=c++17")
+    extra_cflags.append("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else "")
+    return extra_cflags
+
+
 def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
     res_len = width - len(m)
     left_len = int(res_len / 2)
@@ -164,10 +176,15 @@ def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
 lib = load(name='flash_attn_lib',
            sources=get_build_sources(),
            extra_cuda_cflags=get_build_cuda_cflags(),
-           extra_cflags=['-std=c++17'],
+           extra_cflags=get_build_cflags(),
            verbose=args.verbose)


+if not args.build_others:
+    setattr(lib, "flash_attn_mma_stages_split_q_shared_qkv_s2g_o",
+            lambda q, k, v, o, s: o)
+
+
 def get_mha_tflops(B: int, H: int, N: int, D: int, secs: float=1.0,
                    only_matmul: bool = False):
     # Q @ K^T FLOPs
@@ -235,6 +252,10 @@ def run_benchmark(perf_func: callable,
                 hit_hints = True
         if not hit_hints:
             return None, None
+
+    if not args.build_others:
+        if "s2g-o" in tag:
+            return None, None

     if "sdpa" in tag and (not args.run_torch_sdpa):
         return None, None
@@ -434,6 +455,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
         "mma(split-q+share-qkv+swizzle-qk+stage2)": 128,
         "mma(split-q+share-qkv+swizzle-qkv+stage1)": 256,
         "mma(split-q+share-qkv+swizzle-qkv+stage2)": 128,
+        "mma(split-q+share-qkv+s2g-o+stage1)": 256,
+        "mma(split-q+share-qkv+s2g-o+stage2)": 128,
         # Split-Q + QK Fine-grained Tiling
         "mma(split-q+tiling-qk+stage1)": 1024,
         "mma(split-q+tiling-qk+stage2)": 1024,
@@ -482,6 +505,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
         # Split-Q + Fully Shared QKV SMEM + Swizzle
         out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
         out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+        out_mma_share_qkv_s2g1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_s2g_o, q, k, v, "mma(split-q+share-qkv+s2g-o+stage1)", o, stages=1)
+        out_mma_share_qkv_s2g2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_s2g_o, q, k, v, "mma(split-q+share-qkv+s2g-o+stage2)", o, stages=2)
         out_mma_share_qkv_sq1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_q, q, k, v, "mma(split-q+share-qkv+swizzle-q+stage1)", o, stages=1)
         out_mma_share_qkv_sq2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_q, q, k, v, "mma(split-q+share-qkv+swizzle-q+stage2)", o, stages=2)
         out_mma_share_qkv_sqk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv_swizzle_qk, q, k, v, "mma(split-q+share-qkv+swizzle-qk+stage1)", o, stages=1)
@@ -524,6 +549,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
         # Split-Q + Fully Shared QKV SMEM
         check_all_close(out_flash, out_mma_share_qkv1, "out_mma_share_qkv1", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all)
+        check_all_close(out_flash, out_mma_share_qkv_s2g1, "out_mma_share_qkv_s2g1", args.check_all)
+        check_all_close(out_flash, out_mma_share_qkv_s2g2, "out_mma_share_qkv_s2g2", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sq1, "out_mma_share_qkv_sq1", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sq2, "out_mma_share_qkv_sq2", args.check_all)
         check_all_close(out_flash, out_mma_share_qkv_sqk1, "out_mma_share_qkv_sqk1", args.check_all)
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+__pycache__
+*.onnx
+*.engine
+*.pt
+*.pth
+*.nsys*
+*.ncu*
+*.sqlite*
+*.engine
+*.bin
+outupt
