Commit 2b7949c

jaemzfleming and mgoin authored
AQLM CUDA support (#3287)
Co-authored-by: mgoin <[email protected]>
1 parent 62b5166 commit 2b7949c

14 files changed (+1592 −11 lines)
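For context on what the new kernels compute: AQLM stores each linear layer as integer codes, one or more float16 codebooks, and per-output-group scales. The two new ops registered by this commit are aqlm_gemm, a quantized GEMM for that format, and aqlm_dequant, which expands the codes back into a dense fp16 weight; the shape conventions are documented in the benchmark below. The following is a minimal, illustrative sketch of the generic dequantization idea under those shape assumptions — it is not the vLLM implementation, and the helper name is made up:

import torch


def aqlm_dequantize_sketch(
    codes: torch.Tensor,      # [num_out_groups, num_in_groups, num_codebooks], signed ints
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,     # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    # Illustrative only: not the vLLM kernel or its Python fallback.
    num_out_groups, num_in_groups, num_codebooks = codes.shape
    _, codebook_size, out_group_size, in_group_size = codebooks.shape

    # Codes are stored as signed integers; reinterpret them as indices in [0, codebook_size).
    idx = codes.to(torch.int64) % codebook_size

    # Look up every code in its codebook and sum the contributions of all codebooks
    # (additive quantization). Result: [out_groups, in_groups, out_group_size, in_group_size].
    books = torch.arange(num_codebooks, device=codes.device)
    dequant = codebooks[books, idx].sum(dim=2)

    # Reassemble the groups into a dense [out_features, in_features] weight and apply
    # the per-output-group scale (out_group_size is 1 for the 1x16 and 2x8 schemes here).
    weight = dequant.permute(0, 2, 1, 3).reshape(num_out_groups * out_group_size,
                                                 num_in_groups * in_group_size)
    return weight * scales.view(num_out_groups, 1).repeat_interleave(out_group_size, dim=0)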


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -173,6 +173,7 @@ set(VLLM_EXT_SRC
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
+    "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/marlin_cuda_kernel.cu"
     "csrc/custom_all_reduce.cu")
Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
    input: torch.Tensor,  # [..., in_features]
    weights: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
                count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        print(f' | {kernel_dur_us:.0f}', end='')

    print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
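As a usage note (the script's path is not shown above, so assume it is saved as benchmark_aqlm.py): running it with, for example, --nbooks 2 --bits 8 sweeps the listed m values over each k/partition combination and writes one row of per-method timings, in microseconds, to aqlm_benchmark_2x8.csv, while the test mode prints the generic and CUDA-dequantized weights side by side for visual comparison. One caveat: because --test is declared with type=bool, argparse treats any non-empty value (including the string "False") as true, so omit the flag entirely to benchmark rather than test.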

csrc/ops.h

Lines changed: 15 additions & 0 deletions
@@ -86,6 +86,21 @@ void gelu_fast(
   torch::Tensor& input);
 
 #ifndef USE_ROCM
+torch::Tensor aqlm_gemm(
+  const torch::Tensor& input,
+  const torch::Tensor& codes,
+  const torch::Tensor& codebooks,
+  const torch::Tensor& scales,
+  const torch::Tensor& codebook_partition_sizes,
+  const std::optional<torch::Tensor>& bias
+);
+
+torch::Tensor aqlm_dequant(
+  const torch::Tensor& codes,
+  const torch::Tensor& codebooks,
+  const torch::Tensor& codebook_partition_sizes
+);
+
 torch::Tensor awq_gemm(
   torch::Tensor _in_feats,
   torch::Tensor _kernel,

csrc/pybind.cpp

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
   // Quantization ops
 #ifndef USE_ROCM
+  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
+  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
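Once these bindings are registered, both ops are reachable from Python as vllm._C.ops.aqlm_gemm and vllm._C.ops.aqlm_dequant, which is exactly how the benchmark above drives them. Below is a minimal shape sketch reusing the benchmark's conventions (placeholder sizes; requires a CUDA build of vLLM). The codebook_partition_sizes argument is the parts tensor of output partition sizes, so a fused weight made up of several output partitions (e.g. the (4096, 4096, 4096) case in the benchmark) carries parts.shape[0] * nbooks codebooks.

import torch
from vllm._C import ops

# Placeholder problem size mirroring the benchmark above (1x16 scheme: 1 codebook, 16-bit codes).
m, k, nbooks, bits = 4, 4096, 1, 16
parts = torch.tensor((4096, ))          # output partition sizes; n = parts.sum()
n = parts.sum().item()
device = torch.device('cuda:0')

input = torch.randn((1, m, k), dtype=torch.float16, device=device)
codes = torch.randint(-(1 << bits) // 2, (1 << bits) // 2,
                      size=(n, k // 8, nbooks), dtype=torch.int16, device=device)
codebooks = torch.randn((parts.shape[0] * nbooks, 1 << bits, 1, 8),
                        dtype=torch.float16, device=device)
scales = torch.randn((n, 1, 1, 1), dtype=torch.float16, device=device)

# Quantized GEMM path: multiply directly from the compressed representation.
out = ops.aqlm_gemm(input, codes, codebooks, scales, parts, None)   # [1, m, n]
# Decompression path: recover a dense fp16 weight matrix.
weights = ops.aqlm_dequant(codes, codebooks, parts)                 # [n, k]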
