 #include "../util.h"
 #include "../util.cuh"
 #include "codebook.cuh"
+#include "exl3_devctx.cuh"
 #include <cmath>
 
 #define NUM_THREADS 1024
+#define H_INF __ushort_as_half(0x7c00)
+#define H_NINF __ushort_as_half(0xfc00)
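+// 0x7c00 / 0xfc00 are the IEEE 754 fp16 bit patterns for +inf / -inf, used as
+// sentinel costs for forbidden trellis edges below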
 
 template <int K, int cb>
 __global__ __launch_bounds__(MIN(NUM_THREADS, 65536 >> K))
@@ -16,10 +19,13 @@ void quantize_tiles_kernel
     const float* __restrict__ input_tiles_ptr,
     float* __restrict__ output_tiles_ptr,
     uint16_t* __restrict__ output_indices_ptr,
-    float* __restrict__ temp_costs_ptr,
+    half* __restrict__ temp_costs_ptr,
     uint16_t* __restrict__ temp_edges_ptr
 )
 {
+    extern __shared__ uint8_t shbuf[];
+    uint8_t* sh = shbuf;
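+    // All shared buffers below are carved out of this single dynamic allocation
+    // by bumping the sh cursor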
+
     int tile_idx = blockIdx.x;
     int thread = threadIdx.x;
 
@@ -32,46 +38,53 @@ void quantize_tiles_kernel
     uint16_t* output_indices = output_indices_ptr + 256 * tile_idx;
     uint16_t* temp_edges = temp_edges_ptr + 256 * edges * tile_idx;
 
-    // K >= 4 lets temp_costs fit in shmem, otherwise fall back to global temp buffer
-    __shared__ float sh_temp_costs[K >= 4 ? 2 * 65536 >> K : 1];
-    float* temp_costs = K >= 4 ? sh_temp_costs : temp_costs_ptr + 2 * edges * tile_idx;
-    float* temp_costs_inc = temp_costs + edges;
+    // Tile buffer
+    half* sh_input_tile = (half*) sh; sh += 256 * sizeof(half);
+
+    half* sh_min = (half*) sh; sh += 32 * sizeof(half);
+    int* sh_idx = (int*) sh; sh += 32 * sizeof(int);
+
+    // K >= mshk lets temp_costs fit in shmem, otherwise fall back to global temp buffer
+    constexpr int mshk = 2;
+    half* sh_temp_costs = (half*) sh;
+    half* temp_costs = K >= mshk ? sh_temp_costs : temp_costs_ptr + 2 * edges * tile_idx;
+    half* temp_costs_inc = temp_costs + edges;
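+    // temp_costs / temp_costs_inc form a ping-pong pair, swapped once per step in forward()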
 
     // Fetch input tile to shmem
-    __shared__ float sh_input_tile[256];
-    if (thread < 256) sh_input_tile[thread] = input_tile[thread];
+    if (thread < 256) sh_input_tile[thread] = __float2half_rn(input_tile[thread]);
     __syncthreads();
 
     auto forward = [&](int roll, int pre_state)
     {
         int ri = roll % 256;
+        half dh, err, min_err, w;
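+        // One set of half registers reused by both edge-relaxation loops below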
 
         // temp_costs_inc[z] is the cost/cumulative error of an incoming edge from state (z & edge_mask)
-        float* t = temp_costs;
+        half* t = temp_costs;
         temp_costs = temp_costs_inc;
         temp_costs_inc = t;
 
         for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS)
         {
-            float w = sh_input_tile[ri];
+            w = sh_input_tile[ri];
 
             int state = out_edge_idx;
             int in_edge_idx = state >> K;
-            float err = decode_3inst_f_diff<cb>(state, w);
-            err = err * err;
-            if (pre_state >= 0 && in_edge_idx != pre_state) err = 1e30f;
-            float min_err = err;
+            dh = __hsub(decode_3inst<cb>(state), w);
+            err = __hmul(dh, dh);
+            if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF;
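+            // Edges not arriving from the forced pre_state are priced at +inf so
+            // the min-reduction can never select them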
+            min_err = err;
             int min_in_edge = in_edge_idx;
 
             #pragma unroll
             for (int k = 1; k < max_q; ++k)
             {
                 state = (k << Kr) | out_edge_idx;
                 in_edge_idx = state >> K;
-                err = decode_3inst_f_diff<cb>(state, w);
-                err = err * err;
-                if (pre_state >= 0 && in_edge_idx != pre_state) err = 1e30f;
-                if (err < min_err) { min_err = err; min_in_edge = in_edge_idx; }
+                dh = __hsub(decode_3inst<cb>(state), w);
+                err = __hmul(dh, dh);
+                if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF;
+                if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; }
             }
 
             temp_costs[out_edge_idx] = min_err;
@@ -93,23 +106,23 @@ void quantize_tiles_kernel
 
         for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS)
         {
-            float w = sh_input_tile[ri];
+            w = sh_input_tile[ri];
 
             int state = out_edge_idx;
             int in_edge_idx = state >> K;
-            float err = decode_3inst_f_diff<cb>(state, w);
-            err = err * err + temp_costs_inc[in_edge_idx];
-            float min_err = err;
+            dh = __hsub(decode_3inst<cb>(state), w);
+            err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]);
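+            // __hfma fuses the squared error with the incoming edge's cumulative
+            // cost in a single half-precision multiply-add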
+            min_err = err;
             int min_in_edge = in_edge_idx;
 
             #pragma unroll
             for (int k = 1; k < max_q; ++k)
             {
                 state = (k << Kr) | out_edge_idx;
                 in_edge_idx = state >> K;
-                err = decode_3inst_f_diff<cb>(state, w);
-                err = err * err + temp_costs_inc[in_edge_idx];
-                if (err < min_err) { min_err = err; min_in_edge = in_edge_idx; }
+                dh = __hsub(decode_3inst<cb>(state), w);
+                err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]);
+                if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; }
             }
 
             temp_costs[out_edge_idx] = min_err;
@@ -125,16 +138,13 @@ void quantize_tiles_kernel
     {
         // Find the final state with the lowest total cost. Return value is only valid in thread 0
 
-        float local_min = 1e30f;
+        half local_min = H_INF;
         int local_idx = -1;
+        #pragma unroll
         for (int e = threadIdx.x; e < edges; e += NUM_THREADS)
        {
-            float v = temp_costs_inc[e];
-            if (v < local_min)
-            {
-                local_min = v;
-                local_idx = e;
-            }
+            half v = temp_costs_inc[e];
+            if (__hlt(v, local_min)) { local_min = v; local_idx = e; }
         }
 
         // Shuffle reduction
@@ -144,33 +154,30 @@ void quantize_tiles_kernel
         #pragma unroll
         for (int offset = 16; offset > 0; offset >>= 1)
         {
-            float other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
+            half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
             int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32);
-            if (other_min < local_min)
+            if (__hlt(other_min, local_min))
             {
                 local_min = other_min;
                 local_idx = other_idx;
             }
         }
 
-        __shared__ float s_min[32];
-        __shared__ int s_idx[32];
-
-        s_min[warp_id] = local_min;
-        s_idx[warp_id] = local_idx;
+        sh_min[warp_id] = local_min;
+        sh_idx[warp_id] = local_idx;
         __syncthreads();
 
         if (warp_id == 0)
         {
-            local_min = lane_id * 32 < edges && thread < NUM_THREADS / 32 ? s_min[lane_id] : 1e31f;
-            local_idx = thread < NUM_THREADS ? s_idx[lane_id] : 0;
+            local_min = lane_id * 32 < edges && thread < NUM_THREADS / 32 ? sh_min[lane_id] : H_INF;
+            local_idx = thread < NUM_THREADS ? sh_idx[lane_id] : 0;
 
             #pragma unroll
             for (int offset = 16; offset > 0; offset >>= 1)
             {
-                float other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
+                half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32);
                 int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32);
-                if (other_min < local_min)
+                if (__hlt(other_min, local_min))
                 {
                     local_min = other_min;
                     local_idx = other_idx;
@@ -206,10 +213,9 @@ void quantize_tiles_kernel
         }
 
         // Broadcast to block
-        __shared__ int broadcast;
-        if (thread == 0) broadcast = edge;
+        if (thread == 0) sh_idx[0] = edge;
         __syncthreads();
-        edge = broadcast;
+        edge = sh_idx[0];
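+        // sh_idx[0] doubles as the broadcast slot, replacing the dedicated __shared__ int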
 
         return edge;
     };
@@ -270,8 +276,9 @@ void quantize_tiles
     int threads = MIN(NUM_THREADS, edges);
 
     int num_tiles = input_tiles.size(0);
+    if (!num_tiles) return;
 
-    TORCH_CHECK_DTYPE(temp_costs, kFloat);
+    TORCH_CHECK_DTYPE(temp_costs, kHalf);
     TORCH_CHECK_DIM(temp_costs, 3);
     TORCH_CHECK_SIZE(temp_costs, 1, 2);
     TORCH_CHECK_SIZE(temp_costs, 2, edges);
@@ -281,7 +288,10 @@ void quantize_tiles
     TORCH_CHECK_SIZE(temp_edges, 1, 256);
     TORCH_CHECK_SIZE(temp_edges, 2, edges);
 
-    int max_batch_size = temp_costs.size(0);
+    int device;
+    cudaGetDevice(&device);
+    int num_sms = DevCtx::instance().get_num_sms(device);
+    int max_batch_size = MIN(temp_costs.size(0), num_sms);
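+    // At 1024 threads and this shmem footprint a block should occupy a full SM,
+    // so blocks beyond num_sms would presumably only queue up; cap the batch there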
 
     int cb = 0;
     if (mcg) cb = 1;
@@ -295,13 +305,22 @@ void quantize_tiles
         const float* input_tiles_ptr = ((const float*) input_tiles.data_ptr()) + 256 * batch_i;
         float* output_tiles_ptr = ((float*) output_tiles.data_ptr()) + 256 * batch_i;
         uint16_t* output_indices_ptr = ((uint16_t*) output_indices.data_ptr()) + 256 * batch_i;
-        float* temp_costs_ptr = (float*) temp_costs.data_ptr();
+        half* temp_costs_ptr = (half*) temp_costs.data_ptr();
         uint16_t* temp_edges_ptr = (uint16_t*) temp_edges.data_ptr();
 
         int bsz = batch_j - batch_i;
         int kernel_idx = K - 1 + 8 * cb;
+        int shmem = 2 * (65536 >> K) * sizeof(half) + 512 + 64 + 128;
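+        // 2 * edges halves for the cost ping-pong, plus 512 B input tile,
+        // 64 B sh_min and 128 B sh_idx, matching the carve-out in the kernel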
+
+        cudaFuncSetAttribute
+        (
+            quantize_tiles_kernel_instances[kernel_idx],
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            shmem
+        );
+        cuda_check(cudaPeekAtLastError());
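+        // Opting in to more dynamic shmem than the 48 KiB default is required
+        // here, since 2 * (65536 >> K) halves exceeds that limit for small K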
 
-        quantize_tiles_kernel_instances[kernel_idx]<<<bsz, threads, 0, stream>>>
+        quantize_tiles_kernel_instances[kernel_idx]<<<bsz, threads, shmem, stream>>>
         (
             input_tiles_ptr,
             output_tiles_ptr,