Skip to content

Commit c4867ed

Browse files
committed
GEMM kernel optimizations
1 parent 4710c01 commit c4867ed

15 files changed

+110
-112
lines changed

exllamav3/exllamav3_ext/quant/codebook.cuh

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ __device__ inline half decode_3inst(uint32_t x)
66
{
77
x *= 89226354u;
88
x += 64248484u;
9-
x &= 0b10001111111111111000111111111111u;
10-
x ^= 0b00111011011000000011101101100000u;
9+
x = (x & 0x8FFF8FFFu) ^ 0x3B603B60u;
10+
// x &= 0b10001111111111111000111111111111u;
11+
// x ^= 0b00111011011000000011101101100000u;
12+
// Compiler doesn't automatically generate LOP3
13+
asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x));
1114
half2_uint32 xu(x);
1215
return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2));
1316
}
@@ -18,10 +21,13 @@ __device__ inline half2 decode_3inst_2(uint32_t x0, uint32_t x1)
1821
x1 *= 89226354u;
1922
x0 += 64248484u;
2023
x1 += 64248484u;
21-
x0 &= 0b10001111111111111000111111111111u;
22-
x1 &= 0b10001111111111111000111111111111u;
23-
x0 ^= 0b00111011011000000011101101100000u;
24-
x1 ^= 0b00111011011000000011101101100000u;
24+
// x0 &= 0b10001111111111111000111111111111u;
25+
// x1 &= 0b10001111111111111000111111111111u;
26+
// x0 ^= 0b00111011011000000011101101100000u;
27+
// x1 ^= 0b00111011011000000011101101100000u;
28+
// Compiler doesn't automatically generate LOP3
29+
asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x0));
30+
asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x1));
2531
half2_uint32 xu0(x0);
2632
half2_uint32 xu1(x1);
2733
half2 d0 = __halves2half2(__low2half(xu0.as_half2), __low2half(xu1.as_half2));

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_1.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,15 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_1.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b1[] = {
13-
nullptr,
14-
exl3_gemm_kernel<1, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<1, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<1, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<1, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(1, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b1[] = {
21-
nullptr,
22-
exl3_gemm_kernel<1, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<1, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<1, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<1, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(1, false)
2619
};
2720

2821

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_2.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,15 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_2.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b2[] = {
13-
nullptr,
14-
exl3_gemm_kernel<2, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<2, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<2, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<2, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(2, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b2[] = {
21-
nullptr,
22-
exl3_gemm_kernel<2, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<2, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<2, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<2, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(2, false)
2619
};
2720

2821

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_3.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,15 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_3.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b3[] = {
13-
nullptr,
14-
exl3_gemm_kernel<3, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<3, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<3, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<3, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(3, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b3[] = {
21-
nullptr,
22-
exl3_gemm_kernel<3, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<3, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<3, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<3, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(3, false)
2619
};
2720

2821

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_4.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,15 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_4.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b4[] = {
13-
nullptr,
14-
exl3_gemm_kernel<4, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<4, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<4, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<4, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(4, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b4[] = {
21-
nullptr,
22-
exl3_gemm_kernel<4, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<4, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<4, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<4, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(4, false)
2619
};
2720

2821

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_5.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,15 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_5.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b5[] = {
13-
nullptr,
14-
exl3_gemm_kernel<5, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<5, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<5, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<5, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(5, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b5[] = {
21-
nullptr,
22-
exl3_gemm_kernel<5, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<5, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<5, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<5, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(5, false)
2619
};
2720

2821

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_6.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,13 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_6.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b6[] = {
13-
nullptr,
14-
exl3_gemm_kernel<6, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<6, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<6, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<6, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(6, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b6[] = {
21-
nullptr,
22-
exl3_gemm_kernel<6, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<6, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<6, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<6, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(6, false)
2619
};

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_7.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,13 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_7.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b7[] = {
13-
nullptr,
14-
exl3_gemm_kernel<7, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<7, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<7, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<7, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(7, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b7[] = {
21-
nullptr,
22-
exl3_gemm_kernel<7, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<7, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<7, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<7, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(7, false)
2619
};

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_8.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,13 @@ namespace cg = cooperative_groups;
77
#include "../../util.cuh"
88
#include "../../ptx.cuh"
99
#include "../exl3_gemm_kernel.cuh"
10+
#include "../exl3_gemv_kernel.cuh"
1011
#include "exl3_comp_unit_8.cuh"
1112

1213
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b8[] = {
13-
nullptr,
14-
exl3_gemm_kernel<8, true, EXL3_GEMM_SHAPE_1>,
15-
exl3_gemm_kernel<8, true, EXL3_GEMM_SHAPE_2>,
16-
exl3_gemm_kernel<8, true, EXL3_GEMM_SHAPE_3>,
17-
exl3_gemm_kernel<8, true, EXL3_GEMM_SHAPE_4>
14+
EXL3_GEMM_KERNEL_INSTANCES(8, true)
1815
};
1916

2017
fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b8[] = {
21-
nullptr,
22-
exl3_gemm_kernel<8, false, EXL3_GEMM_SHAPE_1>,
23-
exl3_gemm_kernel<8, false, EXL3_GEMM_SHAPE_2>,
24-
exl3_gemm_kernel<8, false, EXL3_GEMM_SHAPE_3>,
25-
exl3_gemm_kernel<8, false, EXL3_GEMM_SHAPE_4>
18+
EXL3_GEMM_KERNEL_INSTANCES(8, false)
2619
};

exllamav3/exllamav3_ext/quant/exl3_dq.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
__device__ __forceinline__ uint32_t fshift(uint32_t b, uint32_t a, int shift)
66
{
7-
// uint64_t merged = ((uint64_t)b << 32) | (uint64_t) a;
8-
// return (uint32_t)(merged >> shift);
7+
uint64_t merged = ((uint64_t)a << 32) | (uint64_t) b;
8+
return (uint32_t)(merged >> shift);
99

10-
// Conditional funnel shift is somehow faster
11-
if (shift < 32) return __funnelshift_r(b, a, shift);
12-
return a >> (shift - 32);
10+
// Conditional funnel shift is somehow no longer faster
11+
// if (shift < 32) return __funnelshift_r(b, a, shift);
12+
// return a >> (shift - 32);
1313
}
1414

1515
template <int bits>

0 commit comments

Comments
 (0)