import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

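# Run all benchmarks on the first visible GPU.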
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


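# Baseline: a plain fp16 matmul on pre-materialized weights; `scales` is
# accepted for signature parity but unused.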
def torch_mult(
        input: torch.Tensor,  # [..., in_features]
        weights: torch.Tensor,
        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


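# Dequantize the codes to full weights with the CUDA kernel, run the matmul,
# then apply the per-output-channel scales to the output activations.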
def dequant_out_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


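# Dequantize the codes, apply the scales to the weights themselves, then run
# the matmul.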
def dequant_weight_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


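# Dequantize the codes and matmul without applying scales at all, so the cost
# of scaling can be isolated by comparison with the variants above.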
def dequant_no_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 CUDA decompression/dequant kernels
# against the generic PyTorch version. This is a visual check only: the
# seeded weights are printed so the two paths can be compared by eye.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

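    # Seed the first 16 codebook entries with recognizable values: each book
    # contributes a different decimal place, so combined entries are easy to
    # tell apart in the printout.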
    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
            count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    # Note: argparse's type=bool treats any non-empty string as True, so a
    # store_true flag is used instead.
    parser.add_argument(
        "--test",
        action="store_true",
        help="Run the decompression/dequant tester rather than benchmarking")
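    # Example invocation (script name assumed):
    #   python benchmark_aqlm.py --nbooks 2 --bits 8 --test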

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

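    # The CSV rows are written to the redirected stdout; per-m progress still
    # reaches the terminal through sys.__stdout__.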
    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ks_and_partitions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                             (4096, (11008, 11008)), (11008, (4096, )))

        # Reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ks_and_partitions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        # Report the best trial rather than just the last one.
        print(f' | {best_time_us:.0f}', end='')

    print('')


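# Times num_calls back-to-back launches of `method` and returns the average
# duration of one call in milliseconds.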
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

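    # CUDA events time the GPU work itself; a single synchronize after the
    # last call is enough to read back the elapsed time.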
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())