Commit beb89f6

AWQ: Up to 2.66x higher throughput (#2566)
1 parent 390b495 commit beb89f6

4 files changed: +127 −1 lines changed


csrc/ops.h

Lines changed: 8 additions & 0 deletions

@@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
   torch::Tensor _scaling_factors,
   torch::Tensor _zeros,
   int split_k_iters);
+
+torch::Tensor awq_dequantize(
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters,
+  int thx,
+  int thy);
 #endif
 
 void squeezellm_gemm(

csrc/pybind.cpp

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifndef USE_ROCM
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");

csrc/quantization/awq/gemm_kernels.cu

Lines changed: 108 additions & 0 deletions

@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 #endif
 }
 
+__global__ void __launch_bounds__(64) dequantize_weights(
+    int* __restrict__ B,
+    half* __restrict__ scaling_factors,
+    int* __restrict__ zeros,
+    half* __restrict__ C,
+    int G
+)
+{
+  int j_factors1 = 4;
+  int row_stride2 = 4;
+  int split_k_iters = 1;
+  static constexpr uint32_t ZERO = 0x0;
+  half B_shared[32 * (128 + 8)];
+
+  half* B_shared_ptr2 = B_shared;
+
+  half B_shared_warp[32];
+  int OC = 512;
+
+  int N = blockDim.x * gridDim.x;  // 2
+  int col = (blockIdx.x * blockDim.x + threadIdx.x);
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+  int index1 = 8 * col + 8 * row * N;
+  half* C_ptr2 = C + index1;
+
+  int index2 = col + row * N;
+  int* B_ptr2 = B + index2;
+
+  int index3 = col + (int)(row / G) * N;
+  int* zeros_ptr2 = zeros + index3;
+  int index4 = 8 * col + (int)(row / G) * N * 8;
+  half* scaling_factors_ptr2 = scaling_factors + index4;
+
+
+  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+  int j = 0;
+
+  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+  for (int i = 0; i < 8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
 } // namespace awq
 } // namespace vllm
 
+torch::Tensor awq_dequantize(
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters,
+    int thx,
+    int thy)
+{
+    int in_c = _kernel.size(0);
+    int qout_c = _kernel.size(1);
+    int out_c = qout_c * 8;
+    int G = in_c / _scaling_factors.size(0);
+
+    int x_thread = thx;
+    int y_thread = thy;
+
+    int x_blocks = 1;
+    int y_blocks = 1;
+    if (thx == 0) {
+      x_thread = qout_c;
+    }
+    if (thy == 0) {
+      y_thread = in_c;
+    }
+    if (thx == 0 && thy == 0) {
+      x_thread = 8;
+      y_thread = 8;
+      x_blocks = (int)(qout_c / 8);
+      y_blocks = (int)(in_c / 8);
+    }
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+    auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+    at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
+
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+    dim3 num_blocks(x_blocks, y_blocks);
+    dim3 threads_per_block(x_thread, y_thread);
+
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+        kernel, scaling_factors, zeros, de_kernel, G);
+
+    return _de_kernel;
+}
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
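The inline PTX in the kernel computes, for every packed 4-bit weight, w = (q − z) * s in half precision: dequantize_s4_to_fp16x2 expands eight 4-bit values from one int32, sub.f16x2 subtracts the group's zero point, and fma.rn.f16x2 multiplies by the group's scale. The sketch below reproduces that arithmetic in plain PyTorch purely as a readability aid; it ignores AWQ's interleaved nibble packing order, so it is not a drop-in replacement for the kernel.

# Reference sketch of the per-element math, not part of the commit.
import torch

def dequantize_reference(qweight: torch.Tensor,   # (IC, OC // 8), int32
                         scales: torch.Tensor,    # (IC // G, OC), float16
                         qzeros: torch.Tensor,    # (IC // G, OC // 8), int32
                         group_size: int) -> torch.Tensor:
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    # Unpack eight 4-bit values from every int32 (real AWQ uses an
    # interleaved nibble order; plain low-to-high order is used here).
    q = (qweight.unsqueeze(-1) >> shifts) & 0xF
    z = (qzeros.unsqueeze(-1) >> shifts) & 0xF
    q = q.reshape(qweight.shape[0], -1).to(torch.float16)
    z = z.reshape(qzeros.shape[0], -1).to(torch.float16)
    # Each group of group_size input rows shares one zero point and scale.
    z = z.repeat_interleave(group_size, dim=0)
    s = scales.repeat_interleave(group_size, dim=0)
    return (q - z) * s    # (IC, OC) float16, same math as the f16x2 sub/fma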

vllm/model_executor/layers/quantization/awq.py

Lines changed: 10 additions & 1 deletion

@@ -153,7 +153,16 @@ def apply_weights(self,
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
+
+        # num_tokens >= threshold
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
+
+        if FP16_MATMUL_HEURISTIC_CONDITION:
+            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+            out = torch.matmul(reshaped_x, out)
+        else:
+            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                               pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)
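This Python-side heuristic is where the throughput gain in the commit title comes from: for batches of at least 256 tokens the layer dequantizes the AWQ weight once and runs a dense float16 torch.matmul, trading extra memory traffic for the higher throughput of fp16 GEMM on large batches, while small decode batches keep the fused awq_gemm path. A short illustration of how the token count in the condition is computed; the 256 threshold is taken from the diff, the tensor shape is made up.

# Illustration only; the input shape is hypothetical.
import torch

x = torch.randn(4, 128, 4096, dtype=torch.float16)  # (batch, seq_len, hidden_size)
num_tokens = x.shape[:-1].numel()                    # 4 * 128 = 512 tokens
use_fp16_matmul = num_tokens >= 256                  # True -> dequantize + torch.matmul
print(num_tokens, use_fp16_matmul)                   # 512 True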
