Skip to content

Commit 4829ea4

Browse files
committed
Rework GEMM kernel tuning
1 parent c3cae87 commit 4829ea4

File tree

16 files changed

+14037
-77
lines changed

16 files changed

+14037
-77
lines changed

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "quant/exl3_gemm.cuh"
2424
#include "quant/exl3_kernel_map.cuh"
2525
#include "quant/util.cuh"
26+
#include "quant/exl3_devctx.cuh"
2627

2728
#include "generator/strings.h"
2829
#include "generator/sampling_basic.cuh"
@@ -87,6 +88,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
8788
m.def("exl3_gemm", &exl3_gemm, "exl3_gemm");
8889
m.def("exl3_gemm_num_kernel_shapes", &exl3_gemm_num_kernel_shapes, "exl3_gemm_num_kernel_shapes");
8990
m.def("exl3_gemm_shape_compat", &exl3_gemm_shape_compat, "exl3_gemm_shape_compat");
91+
m.def("g_get_cc", &g_get_cc, "g_get_cc");
92+
m.def("g_get_num_sms", &g_get_num_sms, "g_get_num_sms");
9093
m.def("exl3_mgemm", &exl3_mgemm, "exl3_mgemm");
9194
m.def("hgemm", &hgemm, "hgemm");
9295
m.def("rope", &rope, "rope");

exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ void BC_BlockSparseMLP::run_bsz1
8585
gate_mcg_mult,
8686
gate_mul1_mult,
8787
min_expert,
88-
max_expert
88+
max_expert,
89+
0
8990
);
9091

9192
exl3_mgemm(
@@ -102,7 +103,8 @@ void BC_BlockSparseMLP::run_bsz1
102103
up_mcg_mult,
103104
up_mul1_mult,
104105
min_expert,
105-
max_expert
106+
max_expert,
107+
0
106108
);
107109

108110
if (act_silu)
@@ -124,7 +126,8 @@ void BC_BlockSparseMLP::run_bsz1
124126
down_mcg_mult,
125127
down_mul1_mult,
126128
min_expert,
127-
max_expert
129+
max_expert,
130+
0
128131
);
129132

130133
if (shared_experts)

exllamav3/exllamav3_ext/libtorch/linear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y)
3131
{
3232
if (x.numel() == x.size(-1))
3333
{
34-
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult);
34+
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult, 0);
3535
}
3636
else
3737
{
3838
at::Tensor xh_ = at::empty_like(x);
39-
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult);
39+
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult, 0);
4040
}
4141

4242
if (bias) y.add_(bias.value());

exllamav3/exllamav3_ext/libtorch/mlp.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void BC_GatedMLP::run_bsz1
3131
gu_mcg_mult,
3232
gu_mul1_mult,
3333
-1,
34-
-1
34+
-1,
35+
0
3536
);
3637

3738
at::Tensor g = gu.select(0, 0).unsqueeze(0);

exllamav3/exllamav3_ext/quant/exl3_devctx.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,14 @@ int* DevCtx::get_locks(int device)
5555
cudaMemset(locks[device], 0, MAX_TILES_C * sizeof(int));
5656
}
5757
return (int*) locks[device];
58+
}
59+
60+
int g_get_cc(int device)
61+
{
62+
return DevCtx::instance().get_cc(device);
63+
}
64+
65+
int g_get_num_sms(int device)
66+
{
67+
return DevCtx::instance().get_num_sms(device);
5868
}

exllamav3/exllamav3_ext/quant/exl3_devctx.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
// Max allowable output size, in tiles. Used to allocate global lock buffer per device for sync across threadblocks
77
#define MAX_TILES_C (1024 * 1024)
88

9+
// Treat hopper and blackwell as same arch for now
910
#define MAX_DEVICES 32
1011
#define CC_OLD 1
1112
#define CC_AMPERE 2
1213
#define CC_ADA 3
1314
#define CC_HOPPER 4
14-
#define CC_BLACKWELL 5
15+
#define CC_BLACKWELL 4
1516

1617
// Singleton to manage context for each device. Stores device attributes and a large-enough lock buffer per device
1718
class DevCtx
@@ -32,4 +33,7 @@ private:
3233
DevCtx() = default;
3334
DevCtx(const DevCtx&) = delete;
3435
DevCtx& operator=(const DevCtx&) = delete;
35-
};
36+
};
37+
38+
int g_get_cc(int device);
39+
int g_get_num_sms(int device);

exllamav3/exllamav3_ext/quant/exl3_gemm.cu

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@ namespace cg = cooperative_groups;
1212
#include "exl3_devctx.cuh"
1313
#include <set>
1414

15+
#define NEW_TUNE_GEMM
16+
#define NEW_TUNE_MGEMM
17+
18+
int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
19+
int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
20+
1521
/*
1622
EXL3 matmul, A @ B -> C
1723
1824
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
19-
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
25+
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
2026
- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
2127
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
2228
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -39,7 +45,8 @@ int exl3_gemm
3945
const c10::optional<at::Tensor>& svh,
4046
int force_shape_idx,
4147
uint32_t mcg_mult,
42-
uint32_t mul1_mult
48+
uint32_t mul1_mult,
49+
int force_num_sms
4350
)
4451
{
4552
const at::cuda::OptionalCUDAGuard device_guard(A.device());
@@ -48,7 +55,7 @@ int exl3_gemm
4855
TORCH_CHECK_DIM(B, 3);
4956
TORCH_CHECK_SHAPES(A, -1, B, 0, 16);
5057
TORCH_CHECK_SHAPES(C, -1, B, 1, 16);
51-
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
58+
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
5259
TORCH_CHECK_DTYPE(A, kHalf);
5360
TORCH_CHECK_DTYPE(B, kShort);
5461
bool c_fp32 = C.dtype() == at::kFloat;
@@ -59,26 +66,26 @@ int exl3_gemm
5966
half* A_had_ptr = nullptr;
6067
if (suh_ptr)
6168
{
62-
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
69+
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
6370
A_had_ptr = (half*) OPTPTR(A_had);
64-
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
65-
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
71+
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
72+
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
6673
}
6774

6875
// Get SV, optionally
6976
const half* svh_ptr = (const half*) OPTPTR(svh);
70-
// if (svh_ptr)
71-
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
77+
// if (svh_ptr)
78+
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
7279

7380
// Device properties
7481
int device;
7582
cudaGetDevice(&device);
76-
int num_sms = DevCtx::instance().get_num_sms(device);
83+
int num_sms = force_num_sms ? force_num_sms : DevCtx::instance().get_num_sms(device);
7784
int cc = DevCtx::instance().get_cc(device);
7885
int* locks = DevCtx::instance().get_locks(device);
7986

8087
// Dispatch
81-
int bits = B.size(2) / 16;
88+
int K = B.size(2) / 16;
8289
const half* A_ptr = (const half*) A.data_ptr();
8390
const uint16_t* B_ptr = (const uint16_t*) B.data_ptr();
8491
void* C_ptr = (void*) C.data_ptr();
@@ -96,21 +103,33 @@ int exl3_gemm
96103
if (mcg_mult) { cb = 1; mult = mcg_mult; }
97104
if (mul1_mult) { cb = 2; mult = mul1_mult; }
98105

99-
int selected_shape;
100106
int block_dim;
101-
fp_exl3_gemm_kernel kernel = select_exl3_gemm_kernel
102-
(
103-
cc, size_m, size_k, size_n, bits, c_fp32,
104-
force_shape_idx, &block_dim, &selected_shape,
105-
&num_sms, cb
106-
);
107-
if (!kernel) return 0;
107+
int shape_idx;
108+
fp_exl3_gemm_kernel kernel;
109+
110+
#ifndef NEW_TUNE_GEMM
111+
kernel = select_exl3_gemm_kernel
112+
(
113+
cc, size_m, size_k, size_n, K, c_fp32,
114+
force_shape_idx, &block_dim, &shape_idx,
115+
&num_sms, cb
116+
);
117+
if (!kernel) return 0;
118+
#else
119+
TResult* tr = select_exl3_gemm_mgemm_kernel_new(cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
120+
if (!tr) return 0;
121+
num_sms = MIN(num_sms, tr->num_sms);
122+
kernel = tr->kernel;
123+
block_dim = tr->block_dim;
124+
shape_idx = tr->shape_idx;
125+
#endif
108126

109127
// Launch
110-
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
128+
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
111129
{
112130
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
113-
kernel_attr_set[device].insert((void*)kernel);
131+
kernel_attr_set[device].insert((void*) kernel);
132+
cuda_check(cudaPeekAtLastError());
114133
}
115134
void* kernelArgs[] =
116135
{
@@ -128,22 +147,24 @@ int exl3_gemm
128147
};
129148
cudaLaunchCooperativeKernel
130149
(
131-
(void*)kernel,
150+
(void*) kernel,
132151
num_sms,
133152
block_dim,
134153
kernelArgs,
135154
SMEM_MAX,
136155
stream
137156
);
138157
cuda_check(cudaPeekAtLastError());
139-
return selected_shape;
158+
159+
// return selected_shape;
160+
return shape_idx;
140161
}
141162

142163
/*
143164
EXL3 multi matmul, A @ B -> C
144165
145166
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
146-
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
167+
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
147168
- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
148169
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
149170
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -169,7 +190,8 @@ int exl3_mgemm
169190
uint32_t mcg_mult,
170191
uint32_t mul1_mult,
171192
int min_index,
172-
int max_index
193+
int max_index,
194+
int force_num_sms
173195
)
174196
{
175197
const at::cuda::OptionalCUDAGuard device_guard(A.device());
@@ -194,6 +216,7 @@ int exl3_mgemm
194216
int bsz = A.size(1);
195217
int bszm_in = A.size(0);
196218
int bszm_out = C.size(0);
219+
int bszm = MAX(bszm_in, bszm_out);
197220

198221
const long* indices_ptr = (const long*) OPTPTR(indices);
199222
const half* weights_ptr = (const half*) OPTPTR(weights);
@@ -219,8 +242,8 @@ int exl3_mgemm
219242
// Device properties
220243
int device;
221244
cudaGetDevice(&device);
222-
int num_sms = DevCtx::instance().get_num_sms(device);
223-
int total_sms = num_sms;
245+
int total_sms = DevCtx::instance().get_num_sms(device);
246+
int num_sms = force_num_sms ? force_num_sms : total_sms;
224247
int cc = DevCtx::instance().get_cc(device);
225248
int* locks = DevCtx::instance().get_locks(device);
226249

@@ -239,25 +262,44 @@ int exl3_mgemm
239262
if (mcg_mult) { cb = 1; mult = mcg_mult; }
240263
if (mul1_mult) { cb = 2; mult = mul1_mult; }
241264

242-
int selected_shape;
265+
int shape_idx;
243266
int block_dim;
244-
fp_exl3_mgemm_kernel kernel = select_exl3_mgemm_kernel
245-
(
246-
cc, size_m, size_k, size_n, K, c_fp32,
247-
force_shape_idx, &block_dim, &selected_shape,
248-
&num_sms, cb, bszm_in, bszm_out
249-
);
250-
if (!kernel) return 0;
267+
fp_exl3_mgemm_kernel kernel;
268+
int concurrency;
269+
270+
#ifndef NEW_TUNE_MGEMM
271+
kernel = select_exl3_mgemm_kernel
272+
(
273+
cc, size_m, size_k, size_n, K, c_fp32,
274+
force_shape_idx, &block_dim, &shape_idx,
275+
&num_sms, cb, bszm_in, bszm_out
276+
);
277+
if (!kernel) return 0;
278+
concurrency = MIN(total_sms / num_sms, bszm_out);
279+
#else
280+
kernel = select_exl3_mgemm_kernel
281+
(
282+
cc, size_m, size_k, size_n, K, c_fp32,
283+
force_shape_idx, &block_dim, &shape_idx,
284+
&num_sms, cb, bszm_in, bszm_out
285+
);
286+
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
287+
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
288+
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
289+
num_sms = tiles;
290+
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
291+
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
292+
concurrency = MIN(total_sms / num_sms, bszm);
293+
#endif
251294

252295
// Launch bigger grid if possible
253-
int concurrency = MIN(total_sms / num_sms, bszm_out);
254296
dim3 block_grid(num_sms, 1, concurrency);
255297

256298
// Launch
257-
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
299+
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
258300
{
259301
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
260-
kernel_attr_set[device].insert((void*)kernel);
302+
kernel_attr_set[device].insert((void*) kernel);
261303
}
262304
void* kernelArgs[] =
263305
{
@@ -279,16 +321,15 @@ int exl3_mgemm
279321
(void*)& min_index,
280322
(void*)& max_index
281323
};
282-
283324
cudaLaunchCooperativeKernel
284325
(
285-
(void*)kernel,
326+
(void*) kernel,
286327
block_grid,
287328
block_dim,
288329
kernelArgs,
289330
SMEM_MAX,
290331
stream
291332
);
292333
cuda_check(cudaPeekAtLastError());
293-
return selected_shape;
334+
return shape_idx;
294335
}

exllamav3/exllamav3_ext/quant/exl3_gemm.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ int exl3_gemm
1212
const c10::optional<at::Tensor>& svh,
1313
int force_shape_idx,
1414
uint32_t mcg_mult,
15-
uint32_t mul1_mult
15+
uint32_t mul1_mult,
16+
int force_num_sms
1617
);
1718

1819
int exl3_mgemm
@@ -30,5 +31,6 @@ int exl3_mgemm
3031
uint32_t mcg_mult,
3132
uint32_t mul1_mult,
3233
int min_index,
33-
int max_index
34+
int max_index,
35+
int force_num_sms
3436
);

0 commit comments

Comments
 (0)