import os
import math
import time
import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
from typing import Optional
from flash_attn import flash_attn_func
import argparse
import random
import numpy as np

torch.set_grad_enabled(False)
torch.set_printoptions(precision=6, threshold=8, edgeitems=3,
                       linewidth=120, sci_mode=False)


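# Seed Python, NumPy, and PyTorch (CPU and all CUDA devices) so the random
# Q/K/V tensors below are reproducible for a given --seed value.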
def set_rand_seed(seed: int = 1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


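# Resolve the project root (two directory levels above this script's folder)
# so the nvcc include path below can be absolute.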
def get_project_dir():
    return os.path.dirname(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))


project_dir = get_project_dir()


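# Command-line flags: --no-rand-{q,k,v,qkv} switch the corresponding tensors to
# all-ones, --naive/--sdpa add reference baselines, --check prints per-block
# outputs for inspection, --range-k fills K with a ramp along the sequence dim,
# and --B/--H/--N/--D pin a single problem size instead of sweeping the defaults.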
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-rand-q", '--no-rq', action="store_true")
    parser.add_argument("--no-rand-k", '--no-rk', action="store_true")
    parser.add_argument("--no-rand-v", '--no-rv', action="store_true")
    parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true")
    parser.add_argument("--naive", action="store_true")
    parser.add_argument("--sdpa", action="store_true")
    parser.add_argument("--check", action="store_true")
    parser.add_argument("--show-all", '--show', action="store_true")
    parser.add_argument("--B", type=int, default=None)
    parser.add_argument("--H", type=int, default=None)
    parser.add_argument("--N", type=int, default=None)
    parser.add_argument("--D", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument("--range-k", '--gk', action="store_true")
    return parser.parse_args()


args = get_args()
print(args)


# Load the CUDA kernels as a Python module (JIT-compiled via torch cpp_extension).
lib = load(name='flash_attn_lib',
           sources=[
               './naive/flash_attn_cuda.cu',
               './mma/flash_attn_mma_naive.cu',
               './mma/flash_attn_mma_stage.cu',
               './pybind/flash_attn.cc'],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math",
               f"-I {project_dir}/kernels/flash-attn/utils",
               "-DFLASH_ATTN_MMA_DEBUG" if args.debug else ""
           ],
           extra_cflags=['-std=c++17'])


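# Time a single attention implementation: zero the output buffer, run `warmup`
# untimed iterations, then average wall-clock time over `iters` timed runs,
# with cuda.synchronize() around the timed region. Kernels that write into a
# preallocated `out` tensor are dispatched with the extra out/s/stages arguments;
# functional baselines (naive, SDPA, flash-attn) are simply called as f(q, k, v).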
def run_benchmark(perf_func: callable,
                  q: torch.Tensor,
                  k: torch.Tensor,
                  v: torch.Tensor,
                  tag: str,
                  out: Optional[torch.Tensor] = None,
                  s: Optional[torch.Tensor] = None,  # DEBUG
                  stages: int = -1,
                  warmup: int = args.warmup,
                  iters: int = args.iters,
                  show_all: bool = args.show_all):
    if out is not None:
        out.fill_(0)
    if s is not None:
        s.fill_(0)
    # Warmup runs (not timed).
    if out is not None:
        for i in range(warmup):
            if stages >= 1:
                if s is not None:
                    perf_func(q, k, v, out, s, stages)
                else:
                    perf_func(q, k, v, out, stages)
            else:
                perf_func(q, k, v, out)
    else:
        for i in range(warmup):
            _ = perf_func(q, k, v)

    torch.cuda.synchronize()
    start = time.time()
    # Timed iterations.
    if out is not None:
        for i in range(iters):
            if stages >= 1:
                if s is not None:
                    perf_func(q, k, v, out, s, stages)
                else:
                    perf_func(q, k, v, out, stages)
            else:
                perf_func(q, k, v, out)
    else:
        for i in range(iters):
            out = perf_func(q, k, v)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"{tag}"
    out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
    out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
    out_val_first = [round(x, 8) for x in out_val_first]
    out_val_last = [round(x, 8) for x in out_val_last]
    out_val = out_val_first[:2]
    out_val.append(out_val_last[-1])
    out_val = [f"{x:<12}" for x in out_val]
    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    time.sleep(0.05)
    return out.clone(), mean_time


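# Build Q/K/V in (B, H, N, D) fp16 layout on the GPU (random by default,
# all-ones or a ramped K depending on the flags) plus a zeroed output buffer
# for the kernels that write their result in place.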
def get_qkvo(B, H, N, D):
    if not (args.no_rand_q or args.no_rand_qkv):
        q = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        q = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    if not (args.no_rand_k or args.no_rand_qkv):
        k = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        k = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
        if args.range_k:
            for i in range(N):
                k[:, :, i, :] = (i + 1) / N
            k = k.cuda().half().contiguous()
    if not (args.no_rand_v or args.no_rand_qkv):
        v = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

    o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

    return q, k, v, o


# Unfused naive attention reference: materializes the full (N, N) score matrix.
def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y


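# A small optional helper (a minimal sketch; nothing below requires it):
# summarize the absolute error of one kernel's output against a reference,
# which is stricter than a single allclose() verdict.
def report_diff(tag: str, ref: torch.Tensor, out: torch.Tensor):
    # Compare in fp32 so the diff itself is not limited by fp16 rounding.
    diff = (ref.float() - out.float()).abs()
    print(f"{tag:>24}: max_abs_err={diff.max().item():.6f}, "
          f"mean_abs_err={diff.mean().item():.6f}")

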
Bs = [1, 2, 4] if not args.B else [args.B]
Hs = [1, 4, 8] if not args.H else [args.H]
Ns = [1024, 2048] if not args.N else [args.N]
Ds = [64, 128] if not args.D else [args.D]
# batch_size, n_head, seq_len, head_dim (B, H, N, D)
BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds]

seed = args.seed if args.seed is not None else random.randint(0, 9999)
set_rand_seed(seed)
print("-" * 100)
print(" " * 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
      f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}")

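# Sweep every (B, H, N, D) combination: build fresh Q/K/V, precompute the
# layouts each kernel expects (K transposed for the staged MMA kernels,
# (B, N, H, D) tensors for flash-attn), then benchmark each implementation.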
for (B, H, N, D) in BHNDs:
    print("-" * 100)
    print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
    q, k, v, o = get_qkvo(B, H, N, D)
    tk = k.transpose(-2, -1).contiguous()  # K^T for the staged MMA kernels
    fq = q.transpose(1, 2).contiguous()    # (B, N, H, D) layout for flash-attn
    fk = k.transpose(1, 2).contiguous()
    fv = v.transpose(1, 2).contiguous()
    torch.cuda.synchronize()

    if args.naive:
        out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)")

    # Kernels using fp16 Tensor Core MMA instructions.
    out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o)
    out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1)
    out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2)
    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")

    if args.sdpa:
        out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
    print("-" * 100)

    torch.cuda.synchronize()
    if args.check:
        # flash-attn returns (B, N, H, D); transpose back to (B, H, N, D) to compare.
        out_flash = out_flash.transpose(1, 2)
        # Print the first few 8-row blocks of both outputs for eyeballing.
        for i in range(min(4, N // 8)):
            print("-" * 100)
            print(f"out_flash[:, :, {i*8}:{(i+1)*8}, :]:\n")
            print(out_flash[:, :, (i*8):(i+1)*8, :].float())
            print(f"out_mma_stage1[:, :, {i*8}:{(i+1)*8}, :]:\n")
            print(out_mma_stage1[:, :, (i*8):(i+1)*8, :].float())
        print("-" * 100)
        print(f"mma(naive) allclose vs flash (atol=1e-2): "
              f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}")
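        # Optional extra diagnostics (a sketch using the report_diff helper
        # defined above; not required for the benchmark itself).
        report_diff("mma(naive)  vs flash", out_flash, out_mma_naive)
        report_diff("mma(stage1) vs flash", out_flash, out_mma_stage1)
        report_diff("mma(stage2) vs flash", out_flash, out_mma_stage2)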