import torch
import time
from torch.utils.cpp_extension import load
from functools import partial
from typing import Callable, Optional

torch.set_grad_enabled(False)

# Load the CUDA kernel as a python module. Undefining the __CUDA_NO_HALF* macros
# enables the half/half2 operator overloads in device code, which the hgemm
# kernels rely on.
# lib = load(name='hgemm_lib',
#            sources=['hgemm.cu'],
#            extra_cuda_cflags=[
#                "-O3",
#                "-U__CUDA_NO_HALF_OPERATORS__",
#                "-U__CUDA_NO_HALF_CONVERSIONS__",
#                "-U__CUDA_NO_HALF2_OPERATORS__",
#                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
#                "--expt-relaxed-constexpr",
#                "--expt-extended-lambda",
#                "--use_fast_math"
#            ],
#            extra_cflags=['-std=c++17'])

def run_benchmark(perf_func: Callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  warmup: int = 1, iters: int = 10,
                  show_all: bool = False):
    # Warmup: run the kernel a few times outside the timed region so that
    # one-off costs (extension loading, cache warmup) are excluded.
    if out is not None:
        out.fill_(0)
        for i in range(warmup):
            perf_func(a, b, out)
    else:
        for i in range(warmup):
            _ = perf_func(a, b)

    # Timed iterations: synchronize before and after so the wall-clock time
    # covers all queued GPU work, then report the mean per-iteration time.
    torch.cuda.synchronize()
    start = time.time()
    if out is not None:
        for i in range(iters):
            perf_func(a, b, out)
    else:
        for i in range(iters):
            out = perf_func(a, b)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters

    # Print the first three output values as a quick sanity check on the result.
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    return out.clone(), mean_time
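
# Optional correctness check (a sketch, not part of the original script): once the
# hgemm extension above is built, a custom kernel can be validated against
# torch.matmul before benchmarking. `kernel_func` is assumed to take (a, b, out)
# and write its result into `out`, matching the benchmark calls below.
def check_correctness(kernel_func: Callable, a: torch.Tensor, b: torch.Tensor,
                      atol: float = 1e-2, rtol: float = 1e-2) -> bool:
    ref = torch.matmul(a, b)       # fp16 reference result from torch (cuBLAS)
    out = torch.zeros_like(ref)
    kernel_func(a, b, out)         # custom kernel writes into `out`
    return torch.allclose(out, ref, atol=atol, rtol=rtol)
# Example (assuming lib is loaded): assert check_correctness(lib.hgemm_naive_f16, a, b)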

# Ms = [1024, 2048, 4096]
# Ns = [1024, 2048, 4096]
# Ks = [256, 512, 1024]
Ms = [1024]
Ns = [1024]
Ks = [256]
MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
for (M, N, K) in MNKs:
    print("-" * 110)
    print(" " * 45 + f"M={M}, N={N}, K={K}")
    a = torch.randn((M, K)).cuda().half().contiguous()
    b = torch.randn((K, N)).cuda().half().contiguous()
    c = torch.randn((M, N)).cuda().half().contiguous()
    # run_benchmark(lib.hgemm_naive_f16, a, b, "f16", c)
    # run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(t4x4bcf)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(t4x4offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_bcf, a, b, "f16x4(t8x8bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack, a, b, "f16x4pack(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_offset, a, b, "f16x8pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(dbuf)", c)
    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
    print("-" * 110)