Commit 7f45c2a

Quantize: Use FP16 costs, use shmem cost buffer down to 2bpw
1 parent 6fecf02 commit 7f45c2a
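Switching the cost buffers from FP32 to FP16 halves the per-tile trellis cost storage (2 × (65536 >> K) entries), which is what lets the cost buffer stay in shared memory down to K = 2 instead of K = 4. A rough budget check, as a standalone host-side sketch mirroring the sizing formula in quantize.cu (the loop bounds and the 48 KB default cap are illustrative, not part of the commit):

    // Hedged sketch, not part of the commit: dynamic shared memory needed per tile
    // by quantize_tiles_kernel after this change, as a function of K (bits per weight).
    #include <cstdio>
    #include <cstddef>

    int main()
    {
        for (int K = 1; K <= 8; ++K)
        {
            size_t edges = 65536u >> K;          // trellis states per step
            size_t shmem = 2 * edges * 2         // two FP16 cost buffers (ping-pong)
                         + 256 * 2               // 256-value input tile as FP16
                         + 32 * 2 + 32 * 4;      // sh_min / sh_idx reduction scratch
            printf("K=%d: %zu bytes %s\n", K, shmem,
                   shmem <= 48 * 1024 ? "(fits the 48 KB default)" : "(needs opt-in)");
        }
        return 0;
    }

For K = 2 this comes to roughly 65 KB, above the 48 KB default, which is why the launch path now raises cudaFuncAttributeMaxDynamicSharedMemorySize before launching; for K = 1 (≈129 KB) the kernel keeps the global temp buffer fallback (constexpr int mshk = 2).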

File tree

2 files changed: +69 −50 lines

exllamav3/exllamav3_ext/quant/quantize.cu
exllamav3/modules/quant/exl3_lib/quantize.py


exllamav3/exllamav3_ext/quant/quantize.cu

Lines changed: 68 additions & 49 deletions
@@ -5,9 +5,12 @@
 #include "../util.h"
 #include "../util.cuh"
 #include "codebook.cuh"
+#include "exl3_devctx.cuh"
 #include <cmath>

 #define NUM_THREADS 1024
+#define H_INF __ushort_as_half(0x7c00)
+#define H_NINF __ushort_as_half(0xfc00)

 template <int K, int cb>
 __global__ __launch_bounds__(MIN(NUM_THREADS, 65536 >> K))
@@ -16,10 +19,13 @@ void quantize_tiles_kernel
     const float* __restrict__ input_tiles_ptr,
     float* __restrict__ output_tiles_ptr,
     uint16_t* __restrict__ output_indices_ptr,
-    float* __restrict__ temp_costs_ptr,
+    half* __restrict__ temp_costs_ptr,
     uint16_t* __restrict__ temp_edges_ptr
 )
 {
+    extern __shared__ uint8_t shbuf[];
+    uint8_t* sh = shbuf;
+
     int tile_idx = blockIdx.x;
     int thread = threadIdx.x;

@@ -32,46 +38,53 @@ void quantize_tiles_kernel
     uint16_t* output_indices = output_indices_ptr + 256 * tile_idx;
     uint16_t* temp_edges = temp_edges_ptr + 256 * edges * tile_idx;

-    // K >= 4 lets temp_costs fit in shmem, otherwise fall back to global temp buffer
-    __shared__ float sh_temp_costs[K >= 4 ? 2 * 65536 >> K : 1];
-    float* temp_costs = K >= 4 ? sh_temp_costs : temp_costs_ptr + 2 * edges * tile_idx;
-    float* temp_costs_inc = temp_costs + edges;
+    // Tile buffer
+    half* sh_input_tile = (half*) sh; sh += 256 * sizeof(half);
+
+    half* sh_min = (half*) sh; sh += 32 * sizeof(half);
+    int* sh_idx = (int*) sh; sh += 32 * sizeof(int);
+
+    // K >= mshk lets temp_costs fit in shmem, otherwise fall back to global temp buffer
+    constexpr int mshk = 2;
+    half* sh_temp_costs = (half*) sh;
+    half* temp_costs = K >= mshk ? sh_temp_costs : temp_costs_ptr + 2 * edges * tile_idx;
+    half* temp_costs_inc = temp_costs + edges;

     // Fetch input tile to shmem
-    __shared__ float sh_input_tile[256];
-    if (thread < 256) sh_input_tile[thread] = input_tile[thread];
+    if (thread < 256) sh_input_tile[thread] = __float2half_rn(input_tile[thread]);
     __syncthreads();

     auto forward = [&](int roll, int pre_state)
     {
         int ri = roll % 256;
+        half dh, err, min_err, w;

         // temp_costs_inc[z] is the cost/cumulative error of an incoming edge from state (z & edge_mask)
-        float* t = temp_costs;
+        half* t = temp_costs;
         temp_costs = temp_costs_inc;
         temp_costs_inc = t;

         for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS)
         {
-            float w = sh_input_tile[ri];
+            w = sh_input_tile[ri];

             int state = out_edge_idx;
             int in_edge_idx = state >> K;
-            float err = decode_3inst_f_diff<cb>(state, w);
-            err = err * err;
-            if (pre_state >= 0 && in_edge_idx != pre_state) err = 1e30f;
-            float min_err = err;
+            dh = __hsub(decode_3inst<cb>(state), w);
+            err = __hmul(dh, dh);
+            if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF;
+            min_err = err;
             int min_in_edge = in_edge_idx;

             #pragma unroll
             for (int k = 1; k < max_q; ++k)
             {
                 state = (k << Kr) | out_edge_idx;
                 in_edge_idx = state >> K;
-                err = decode_3inst_f_diff<cb>(state, w);
-                err = err * err;
-                if (pre_state >= 0 && in_edge_idx != pre_state) err = 1e30f;
-                if (err < min_err) { min_err = err; min_in_edge = in_edge_idx; }
+                dh = __hsub(decode_3inst<cb>(state), w);
+                err = __hmul(dh, dh);
+                if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF;
+                if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; }
             }

             temp_costs[out_edge_idx] = min_err;
@@ -93,23 +106,23 @@ void quantize_tiles_kernel

         for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS)
         {
-            float w = sh_input_tile[ri];
+            w = sh_input_tile[ri];

             int state = out_edge_idx;
             int in_edge_idx = state >> K;
-            float err = decode_3inst_f_diff<cb>(state, w);
-            err = err * err + temp_costs_inc[in_edge_idx];
-            float min_err = err;
+            dh = __hsub(decode_3inst<cb>(state), w);
+            err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]);
+            min_err = err;
             int min_in_edge = in_edge_idx;

             #pragma unroll
             for (int k = 1; k < max_q; ++k)
             {
                 state = (k << Kr) | out_edge_idx;
                 in_edge_idx = state >> K;
-                err = decode_3inst_f_diff<cb>(state, w);
-                err = err * err + temp_costs_inc[in_edge_idx];
-                if (err < min_err) { min_err = err; min_in_edge = in_edge_idx; }
+                dh = __hsub(decode_3inst<cb>(state), w);
+                err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]);
+                if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; }
             }

             temp_costs[out_edge_idx] = min_err;
@@ -125,16 +138,13 @@ void quantize_tiles_kernel
     {
         // Find the final state with the lowest total cost. Return value is only valid in thread 0

-        float local_min = 1e30f;
+        half local_min = H_INF;
         int local_idx = -1;
+        #pragma unroll
         for (int e = threadIdx.x; e < edges; e += NUM_THREADS)
         {
-            float v = temp_costs_inc[e];
-            if (v < local_min)
-            {
-                local_min = v;
-                local_idx = e;
-            }
+            half v = temp_costs_inc[e];
+            if (__hlt(v, local_min)) { local_min = v; local_idx = e; }
         }

         // Shuffle reduction
@@ -144,33 +154,30 @@ void quantize_tiles_kernel
         #pragma unroll
         for (int offset = 16; offset > 0; offset >>= 1)
         {
-            float other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
+            half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
             int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32);
-            if (other_min < local_min)
+            if (__hlt(other_min, local_min))
             {
                 local_min = other_min;
                 local_idx = other_idx;
             }
         }

-        __shared__ float s_min[32];
-        __shared__ int s_idx[32];
-
-        s_min[warp_id] = local_min;
-        s_idx[warp_id] = local_idx;
+        sh_min[warp_id] = local_min;
+        sh_idx[warp_id] = local_idx;
         __syncthreads();

         if (warp_id == 0)
         {
-            local_min = lane_id * 32 < edges && thread < NUM_THREADS / 32 ? s_min[lane_id] : 1e31f;
-            local_idx = thread < NUM_THREADS ? s_idx[lane_id] : 0;
+            local_min = lane_id * 32 < edges && thread < NUM_THREADS / 32 ? sh_min[lane_id] : H_INF;
+            local_idx = thread < NUM_THREADS ? sh_idx[lane_id] : 0;

             #pragma unroll
             for (int offset = 16; offset > 0; offset >>= 1)
             {
-                float other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
+                half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
                 int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32);
-                if (other_min < local_min)
+                if (__hlt(other_min, local_min))
                 {
                     local_min = other_min;
                     local_idx = other_idx;
@@ -206,10 +213,9 @@ void quantize_tiles_kernel
         }

         // Broadcast to block
-        __shared__ int broadcast;
-        if (thread == 0) broadcast = edge;
+        if (thread == 0) sh_idx[0] = edge;
         __syncthreads();
-        edge = broadcast;
+        edge = sh_idx[0];

         return edge;
     };
@@ -270,8 +276,9 @@ void quantize_tiles
     int threads = MIN(NUM_THREADS, edges);

     int num_tiles = input_tiles.size(0);
+    if (!num_tiles) return;

-    TORCH_CHECK_DTYPE(temp_costs, kFloat);
+    TORCH_CHECK_DTYPE(temp_costs, kHalf);
     TORCH_CHECK_DIM(temp_costs, 3);
     TORCH_CHECK_SIZE(temp_costs, 1, 2);
     TORCH_CHECK_SIZE(temp_costs, 2, edges);
@@ -281,7 +288,10 @@ void quantize_tiles
     TORCH_CHECK_SIZE(temp_edges, 1, 256);
     TORCH_CHECK_SIZE(temp_edges, 2, edges);

-    int max_batch_size = temp_costs.size(0);
+    int device;
+    cudaGetDevice(&device);
+    int num_sms = DevCtx::instance().get_num_sms(device);
+    int max_batch_size = MIN(temp_costs.size(0), num_sms);

     int cb = 0;
     if (mcg) cb = 1;
@@ -295,13 +305,22 @@ void quantize_tiles
         const float* input_tiles_ptr = ((const float*) input_tiles.data_ptr()) + 256 * batch_i;
         float* output_tiles_ptr = ((float*) output_tiles.data_ptr()) + 256 * batch_i;
         uint16_t* output_indices_ptr = ((uint16_t*) output_indices.data_ptr()) + 256 * batch_i;
-        float* temp_costs_ptr = (float*) temp_costs.data_ptr();
+        half* temp_costs_ptr = (half*) temp_costs.data_ptr();
         uint16_t* temp_edges_ptr = (uint16_t*) temp_edges.data_ptr();

         int bsz = batch_j - batch_i;
         int kernel_idx = K - 1 + 8 * cb;
+        int shmem = 2 * (65536 >> K) * sizeof(half) + 512 + 64 + 128;
+
+        cudaFuncSetAttribute
+        (
+            quantize_tiles_kernel_instances[kernel_idx],
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            shmem
+        );
+        cuda_check(cudaPeekAtLastError());

-        quantize_tiles_kernel_instances[kernel_idx]<<<bsz, threads, 0, stream>>>
+        quantize_tiles_kernel_instances[kernel_idx]<<<bsz, threads, shmem, stream>>>
         (
             input_tiles_ptr,
             output_tiles_ptr,

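Launching with more than 48 KB of dynamic shared memory requires a per-kernel opt-in, which is what the new cudaFuncSetAttribute call above provides before the launch passes the larger shmem argument. A minimal standalone sketch of that pattern, with a placeholder kernel name (the commit applies it to quantize_tiles_kernel_instances[kernel_idx]):

    #include <cuda_runtime.h>

    __global__ void example_kernel()
    {
        extern __shared__ unsigned char shbuf[];   // dynamic shared memory
        // ... kernel body ...
    }

    // Launch helper: raise the per-kernel dynamic shared memory cap, then launch.
    // Without the attribute, launches requesting more than 48 KB fail.
    cudaError_t launch_large_shmem(dim3 grid, dim3 block, int shmem, cudaStream_t stream)
    {
        cudaError_t err = cudaFuncSetAttribute(
            example_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);
        if (err != cudaSuccess) return err;
        example_kernel<<<grid, block, shmem, stream>>>();
        return cudaPeekAtLastError();
    }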
exllamav3/modules/quant/exl3_lib/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def tensor_core_perm_i(device):
 @lru_cache
 def get_temp_buffers(device, K: int):
     max_batch_size = 256
-    temp_costs = torch.zeros((max_batch_size, 2, 65536 >> K), dtype = torch.float, device = device)
+    temp_costs = torch.zeros((max_batch_size, 2, 65536 >> K), dtype = torch.half, device = device)
     temp_edges = torch.zeros((max_batch_size, 256, 65536 >> K), dtype = torch.short, device = device)
     return temp_costs, temp_edges


0 commit comments