@@ -12,11 +12,17 @@ namespace cg = cooperative_groups;
1212#include " exl3_devctx.cuh"
1313#include < set>
1414
15+ #define NEW_TUNE_GEMM
16+ #define NEW_TUNE_MGEMM
17+
18+ int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
19+ int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
20+
1521/*
1622EXL3 matmul, A @ B -> C
1723
1824- A: row-major A tensor, shape (m, k), dtype float16, contiguous
19- - B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits ), dtype uint16
25+ - B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K ), dtype uint16
2026- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
2127- suh: optional, packed input scales/flips, shape (k//16), dtype float16
2228- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -39,7 +45,8 @@ int exl3_gemm
3945 const c10::optional<at::Tensor>& svh,
4046 int force_shape_idx,
4147 uint32_t mcg_mult,
42- uint32_t mul1_mult
48+ uint32_t mul1_mult,
49+ int force_num_sms
4350)
4451{
4552 const at::cuda::OptionalCUDAGuard device_guard (A.device ());
@@ -48,7 +55,7 @@ int exl3_gemm
4855 TORCH_CHECK_DIM (B, 3 );
4956 TORCH_CHECK_SHAPES (A, -1 , B, 0 , 16 );
5057 TORCH_CHECK_SHAPES (C, -1 , B, 1 , 16 );
51- // TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
58+ // TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
5259 TORCH_CHECK_DTYPE (A, kHalf );
5360 TORCH_CHECK_DTYPE (B, kShort );
5461 bool c_fp32 = C.dtype () == at::kFloat ;
@@ -59,26 +66,26 @@ int exl3_gemm
5966 half* A_had_ptr = nullptr ;
6067 if (suh_ptr)
6168 {
62- // TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
69+ // TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
6370 A_had_ptr = (half*) OPTPTR (A_had);
64- // TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
65- // TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
71+ // TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
72+ // TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
6673 }
6774
6875 // Get SV, optionally
6976 const half* svh_ptr = (const half*) OPTPTR (svh);
70- // if (svh_ptr)
71- // TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
77+ // if (svh_ptr)
78+ // TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
7279
7380 // Device properties
7481 int device;
7582 cudaGetDevice (&device);
76- int num_sms = DevCtx::instance ().get_num_sms (device);
83+ int num_sms = force_num_sms ? force_num_sms : DevCtx::instance ().get_num_sms (device);
7784 int cc = DevCtx::instance ().get_cc (device);
7885 int * locks = DevCtx::instance ().get_locks (device);
7986
8087 // Dispatch
81- int bits = B.size (2 ) / 16 ;
88+ int K = B.size (2 ) / 16 ;
8289 const half* A_ptr = (const half*) A.data_ptr ();
8390 const uint16_t * B_ptr = (const uint16_t *) B.data_ptr ();
8491 void * C_ptr = (void *) C.data_ptr ();
@@ -96,21 +103,33 @@ int exl3_gemm
96103 if (mcg_mult) { cb = 1 ; mult = mcg_mult; }
97104 if (mul1_mult) { cb = 2 ; mult = mul1_mult; }
98105
99- int selected_shape;
100106 int block_dim;
101- fp_exl3_gemm_kernel kernel = select_exl3_gemm_kernel
102- (
103- cc, size_m, size_k, size_n, bits, c_fp32,
104- force_shape_idx, &block_dim, &selected_shape,
105- &num_sms, cb
106- );
107- if (!kernel) return 0 ;
107+ int shape_idx;
108+ fp_exl3_gemm_kernel kernel;
109+
110+ #ifndef NEW_TUNE_GEMM
111+ kernel = select_exl3_gemm_kernel
112+ (
113+ cc, size_m, size_k, size_n, K, c_fp32,
114+ force_shape_idx, &block_dim, &shape_idx,
115+ &num_sms, cb
116+ );
117+ if (!kernel) return 0 ;
118+ #else
119+ TResult* tr = select_exl3_gemm_mgemm_kernel_new (cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
120+ if (!tr) return 0 ;
121+ num_sms = MIN (num_sms, tr->num_sms );
122+ kernel = tr->kernel ;
123+ block_dim = tr->block_dim ;
124+ shape_idx = tr->shape_idx ;
125+ #endif
108126
109127 // Launch
110- if (kernel_attr_set[device].find ((void *)kernel) == kernel_attr_set[device].end ())
128+ if (kernel_attr_set[device].find ((void *) kernel) == kernel_attr_set[device].end ())
111129 {
112130 cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
113- kernel_attr_set[device].insert ((void *)kernel);
131+ kernel_attr_set[device].insert ((void *) kernel);
132+ cuda_check (cudaPeekAtLastError ());
114133 }
115134 void * kernelArgs[] =
116135 {
@@ -128,22 +147,24 @@ int exl3_gemm
128147 };
129148 cudaLaunchCooperativeKernel
130149 (
131- (void *)kernel,
150+ (void *) kernel,
132151 num_sms,
133152 block_dim,
134153 kernelArgs,
135154 SMEM_MAX,
136155 stream
137156 );
138157 cuda_check (cudaPeekAtLastError ());
139- return selected_shape;
158+
159+ // return selected_shape;
160+ return shape_idx;
140161}
141162
142163/*
143164EXL3 multi matmul, A @ B -> C
144165
145166- A: row-major A tensor, shape (m, k), dtype float16, contiguous
146- - B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits ), dtype uint16
167+ - B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K ), dtype uint16
147168- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
148169- suh: optional, packed input scales/flips, shape (k//16), dtype float16
149170- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -169,7 +190,8 @@ int exl3_mgemm
169190 uint32_t mcg_mult,
170191 uint32_t mul1_mult,
171192 int min_index,
172- int max_index
193+ int max_index,
194+ int force_num_sms
173195)
174196{
175197 const at::cuda::OptionalCUDAGuard device_guard (A.device ());
@@ -194,6 +216,7 @@ int exl3_mgemm
194216 int bsz = A.size (1 );
195217 int bszm_in = A.size (0 );
196218 int bszm_out = C.size (0 );
219+ int bszm = MAX (bszm_in, bszm_out);
197220
198221 const long * indices_ptr = (const long *) OPTPTR (indices);
199222 const half* weights_ptr = (const half*) OPTPTR (weights);
@@ -219,8 +242,8 @@ int exl3_mgemm
219242 // Device properties
220243 int device;
221244 cudaGetDevice (&device);
222- int num_sms = DevCtx::instance ().get_num_sms (device);
223- int total_sms = num_sms ;
245+ int total_sms = DevCtx::instance ().get_num_sms (device);
246+ int num_sms = force_num_sms ? force_num_sms : total_sms ;
224247 int cc = DevCtx::instance ().get_cc (device);
225248 int * locks = DevCtx::instance ().get_locks (device);
226249
@@ -239,25 +262,44 @@ int exl3_mgemm
239262 if (mcg_mult) { cb = 1 ; mult = mcg_mult; }
240263 if (mul1_mult) { cb = 2 ; mult = mul1_mult; }
241264
242- int selected_shape ;
265+ int shape_idx ;
243266 int block_dim;
244- fp_exl3_mgemm_kernel kernel = select_exl3_mgemm_kernel
245- (
246- cc, size_m, size_k, size_n, K, c_fp32,
247- force_shape_idx, &block_dim, &selected_shape,
248- &num_sms, cb, bszm_in, bszm_out
249- );
250- if (!kernel) return 0 ;
267+ fp_exl3_mgemm_kernel kernel;
268+ int concurrency;
269+
270+ #ifndef NEW_TUNE_MGEMM
271+ kernel = select_exl3_mgemm_kernel
272+ (
273+ cc, size_m, size_k, size_n, K, c_fp32,
274+ force_shape_idx, &block_dim, &shape_idx,
275+ &num_sms, cb, bszm_in, bszm_out
276+ );
277+ if (!kernel) return 0 ;
278+ concurrency = MIN (total_sms / num_sms, bszm_out);
279+ #else
280+ kernel = select_exl3_mgemm_kernel
281+ (
282+ cc, size_m, size_k, size_n, K, c_fp32,
283+ force_shape_idx, &block_dim, &shape_idx,
284+ &num_sms, cb, bszm_in, bszm_out
285+ );
286+ int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
287+ int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
288+ int tiles = MAX (size_k / tilesize_k * size_n / tilesize_n, 1 );
289+ num_sms = tiles;
290+ if (num_sms * bszm > total_sms) num_sms = MAX (total_sms / bszm, 1 );
291+ if (num_sms <= total_sms && tiles / num_sms > 48 ) num_sms = MIN (total_sms, num_sms * 2 );
292+ concurrency = MIN (total_sms / num_sms, bszm);
293+ #endif
251294
252295 // Launch bigger grid if possible
253- int concurrency = MIN (total_sms / num_sms, bszm_out);
254296 dim3 block_grid (num_sms, 1 , concurrency);
255297
256298 // Launch
257- if (kernel_attr_set[device].find ((void *)kernel) == kernel_attr_set[device].end ())
299+ if (kernel_attr_set[device].find ((void *) kernel) == kernel_attr_set[device].end ())
258300 {
259301 cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
260- kernel_attr_set[device].insert ((void *)kernel);
302+ kernel_attr_set[device].insert ((void *) kernel);
261303 }
262304 void * kernelArgs[] =
263305 {
@@ -279,16 +321,15 @@ int exl3_mgemm
279321 (void *)& min_index,
280322 (void *)& max_index
281323 };
282-
283324 cudaLaunchCooperativeKernel
284325 (
285- (void *)kernel,
326+ (void *) kernel,
286327 block_grid,
287328 block_dim,
288329 kernelArgs,
289330 SMEM_MAX,
290331 stream
291332 );
292333 cuda_check (cudaPeekAtLastError ());
293- return selected_shape ;
334+ return shape_idx ;
294335}
0 commit comments