Skip to content

Commit 0fbfc4b

Browse files
Add GPTQ support (#916)
1 parent c06170c commit 0fbfc4b

35 files changed

+1781
-81
lines changed

benchmarks/benchmark_latency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
8484
parser.add_argument('--tokenizer', type=str, default=None)
8585
parser.add_argument('--quantization',
8686
'-q',
87-
choices=['awq', 'squeezellm', None],
87+
choices=['awq', 'gptq', 'squeezellm', None],
8888
default=None)
8989
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
9090
parser.add_argument('--input-len', type=int, default=32)

benchmarks/benchmark_throughput.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def main(args: argparse.Namespace):
244244
parser.add_argument("--tokenizer", type=str, default=None)
245245
parser.add_argument('--quantization',
246246
'-q',
247-
choices=['awq', 'squeezellm', None],
247+
choices=['awq', 'gptq', 'squeezellm', None],
248248
default=None)
249249
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
250250
parser.add_argument("--n",

csrc/ops.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,15 @@ void squeezellm_gemm(
7777
torch::Tensor mat,
7878
torch::Tensor mul,
7979
torch::Tensor lookup_table);
80+
81+
torch::Tensor gptq_gemm(
82+
torch::Tensor a,
83+
torch::Tensor b_q_weight,
84+
torch::Tensor b_gptq_qzeros,
85+
torch::Tensor b_gptq_scales,
86+
torch::Tensor b_g_idx,
87+
bool use_exllama);
88+
89+
void gptq_shuffle(
90+
torch::Tensor q_weight,
91+
torch::Tensor q_perm);

csrc/pybind.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5252
// Quantization ops
5353
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
5454
#endif
55-
56-
55+
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
56+
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
5757
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
5858

5959
// Cache ops

csrc/quantization/gptq/compat.cuh

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
Copied from https://github.com/turboderp/exllamav2
3+
*/
4+
5+
#ifndef _compat_cuh
6+
#define _compat_cuh
7+
8+
namespace vllm {
9+
namespace gptq {
10+
// atomicAdd for half types, to support CC < 7.x
11+
12+
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
13+
{
14+
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
15+
unsigned int old = *address_as_ui;
16+
unsigned int assumed;
17+
18+
do
19+
{
20+
assumed = old;
21+
__half_raw hsum;
22+
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
23+
half tmpres = __hadd(hsum, val);
24+
hsum = __half_raw(tmpres);
25+
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
26+
old = atomicCAS(address_as_ui, assumed, old);
27+
}
28+
while (assumed != old);
29+
}
30+
31+
// atomicAdd for half2 types
32+
33+
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
34+
{
35+
unsigned int* address_as_ui = (unsigned int*)address;
36+
unsigned int old = *address_as_ui;
37+
unsigned int assumed;
38+
do
39+
{
40+
assumed = old;
41+
half2 old_val = *((half2*)&old);
42+
half2 new_val = __hadd2(old_val, val);
43+
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
44+
}
45+
while (assumed != old);
46+
}
47+
48+
//
49+
50+
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
51+
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
52+
53+
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
54+
55+
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
56+
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
57+
#endif
58+
59+
#endif
60+
#endif
61+
62+
} // namespace gptq
63+
} // namespace vllm
64+
#endif
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
3+
*/
4+
5+
#ifndef _matrix_view_cuh
6+
#define _matrix_view_cuh
7+
8+
#include <cuda_runtime.h>
9+
#include <cuda_fp16.h>
10+
11+
#include "qdq_util.cuh"
12+
13+
namespace vllm {
14+
namespace gptq {
15+
16+
class MatrixView_half
17+
{
18+
public:
19+
const half* data;
20+
const int height;
21+
const int width;
22+
23+
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
24+
: data(data), height(height), width(width)
25+
{ }
26+
27+
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
28+
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
29+
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
30+
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
31+
32+
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
33+
{
34+
half2* ptr = (half2*) item_ptr(row, column);
35+
half2 i01 = ptr[0];
36+
half2 i23 = ptr[1];
37+
items[0] = __low2half(i01);
38+
items[1] = __high2half(i01);
39+
items[2] = __low2half(i23);
40+
items[3] = __high2half(i23);
41+
}
42+
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
43+
{
44+
half2* ptr = (half2*)item_ptr(row, column);
45+
half2 i01 = ptr[0];
46+
half2 i23 = ptr[1];
47+
items[0] = __half2float(__low2half(i01));
48+
items[1] = __half2float(__high2half(i01));
49+
items[2] = __half2float(__low2half(i23));
50+
items[3] = __half2float(__high2half(i23));
51+
}
52+
53+
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
54+
{
55+
half2* ptr = (half2*)item_ptr(row, column);
56+
half2 i01 = ptr[0];
57+
half2 i23 = ptr[1];
58+
items[0] = __half2half2(__low2half(i01));
59+
items[1] = __half2half2(__high2half(i01));
60+
items[2] = __half2half2(__low2half(i23));
61+
items[3] = __half2half2(__high2half(i23));
62+
}
63+
};
64+
65+
class MatrixView_half_rw
66+
{
67+
public:
68+
half* data;
69+
const int height;
70+
const int width;
71+
72+
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
73+
: data(data), height(height), width(width)
74+
{ }
75+
76+
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
77+
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
78+
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
79+
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
80+
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
81+
82+
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
83+
{
84+
half2 v01 = __halves2half2(v0, v1);
85+
half2 v23 = __halves2half2(v2, v3);
86+
half2* ptr = (half2*) item_ptr(row, column);
87+
ptr[0] = v01;
88+
ptr[1] = v23;
89+
}
90+
};
91+
92+
class MatrixView_q4_row
93+
{
94+
public:
95+
const uint32_t* data;
96+
const int height;
97+
const int width;
98+
99+
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
100+
: data(data), height(height), width(width)
101+
{ }
102+
103+
__device__ __forceinline__ int item(int row, int column) const
104+
{
105+
int shift = (column & 0x07) * 4;
106+
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
107+
}
108+
109+
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
110+
{
111+
int shift = (column & 0x07) * 4;
112+
uint32_t d = data[row * width / 8 + column / 8] >> shift;
113+
items[0] = d & 0x0f;
114+
items[1] = (d >> 4) & 0x0f;
115+
}
116+
117+
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
118+
{
119+
int shift = (column & 0x07) * 4;
120+
uint32_t d = data[row * width / 8 + column / 8] >> shift;
121+
items[0] = d & 0x0f;
122+
items[1] = (d >> 4) & 0x0f;
123+
items[2] = (d >> 8) & 0x0f;
124+
items[3] = (d >> 12) & 0x0f;
125+
}
126+
};
127+
128+
class MatrixView_q4_column
129+
{
130+
public:
131+
const uint32_t* data;
132+
const int height;
133+
const int width;
134+
135+
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
136+
: data(data), height(height), width(width)
137+
{ }
138+
139+
__device__ __forceinline__ int item(int row, int column) const
140+
{
141+
int shift = (row & 0x07) * 4;
142+
return (data[row / 8 * width + column] >> shift) & 0x0f;
143+
}
144+
145+
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
146+
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
147+
};
148+
149+
} // namespace gptq
150+
} // namespace vllm
151+
#endif

0 commit comments

Comments
 (0)