Skip to content

Commit afd0e47

Browse files
author
Italo Nicola
committed
Vulkan: OUT_PROD performance improvements
1 parent fdcc1ae commit afd0e47

File tree

3 files changed

+49
-40
lines changed

3 files changed

+49
-40
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,39 @@ void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uin
1717
}
1818

1919
void main() {
20-
// num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
21-
const uint num_iter = 2;
22-
2320
const uint broadcast2 = uint(p.param2);
2421
const uint broadcast3 = p.param3;
2522

26-
uint idx = get_idx();
23+
uint idx_base = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512;
24+
uint idx = idx_base + (gl_GlobalInvocationID.x / 16) * 16 + gl_GlobalInvocationID.x;
25+
26+
if (idx < p.ne) {
27+
uint i0, i1, i2, i3;
28+
get_dst_indices(idx, i0, i1, i2, i3);
2729

28-
[[unroll]] for (uint it = 0; it < num_iter; ++it) {
29-
if (idx < p.ne) {
30-
uint i0, i1, i2, i3;
31-
get_dst_indices(idx, i0, i1, i2, i3);
30+
vec2 acc = vec2(0.0f);
3231

33-
float acc = 0.0f;
32+
if (i0 + 16 < p.ne20) { // XXX
33+
[[unroll]] for (uint k = 0; k < p.ne01; k += 1) {
3434

35-
for (uint k = 0; k < p.ne01; k += 1) {
3635
const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
3736
const uint ib = a_block_base + (i0 / QUANT_K);
3837
const uint iqs = i0 % (QUANT_K / QUANT_R);
39-
const uint upper = (i0 % QUANT_K) / (QUANT_K / QUANT_R);
40-
const uint lower = 1 - upper;
4138

4239
const vec2 v = dequantize(ib, iqs, 0);
4340
const vec2 dm = get_dm(ib, 0);
4441

45-
const float a_val = (v.x * lower + v.y * upper) * dm.x + dm.y;
42+
const vec2 a_vals = v * dm.x + dm.y;
4643

4744
const uint b_idx = src1_idx(i1, k, i2, i3);
4845
const float b = data_b[get_boffset() + b_idx];
49-
acc += a_val * b;
46+
acc += a_vals * b;
5047
}
5148

52-
uint d_idx = dst_idx(i0, i1, i2, i3);
53-
data_d[get_doffset() + d_idx] = acc;
49+
uint d_idx_0 = dst_idx(i0, i1, i2, i3);
50+
data_d[get_doffset() + d_idx_0] = acc.x;
51+
uint d_idx_1 = dst_idx(i0 + 16, i1, i2, i3);
52+
data_d[get_doffset() + d_idx_1] = acc.y;
5453
}
55-
idx += num_threads;
5654
}
5755
}

ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
#include "generic_binary_head.comp"
55
#include "dequant_funcs.comp"
66

7-
const uint quant_group_sz = 2;
8-
const uint num_threads = 512 / quant_group_sz;
7+
const uint num_threads = 256;
98
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
109

1110
void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
@@ -21,7 +20,7 @@ void main() {
2120
const uint broadcast2 = uint(p.param2);
2221
const uint broadcast3 = p.param3;
2322

24-
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x * quant_group_sz;
23+
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x * 2;
2524

2625
uint aoffset = get_aoffset();
2726
uint boffset = get_boffset();
@@ -33,32 +32,27 @@ void main() {
3332

3433
vec2 acc = vec2(0.0);
3534

36-
for (uint k = 0; k < p.ne01; k++) {
37-
if (i0 + 1 >= p.ne20) {
38-
continue;
39-
}
40-
41-
const uint a_block_base = aoffset + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
42-
const uint ib = a_block_base + ((i0) / QUANT_K) * p.nb00;
43-
const uint iqs = ((i0) % QUANT_K) / QUANT_R;
35+
if (i0 + 1 < p.ne20) {
36+
[[unroll]] for (uint k = 0; k < p.ne01; k++) {
4437

45-
const vec2 v = dequantize(ib, iqs, 0);
46-
const vec2 dm = get_dm(ib, 0);
47-
const vec2 a_vals = v * dm.x + dm.y;
38+
const uint a_block_base = aoffset + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
39+
const uint ib = a_block_base + ((i0) / QUANT_K) * p.nb00;
40+
const uint iqs = ((i0) % QUANT_K) / QUANT_R;
4841

49-
const uint b_idx = src1_idx(i1, k, i2, i3);
50-
const float b = data_b[boffset + b_idx];
42+
const vec2 v = dequantize(ib, iqs, 0);
43+
const vec2 dm = get_dm(ib, 0);
44+
const vec2 a_vals = v * dm.x + dm.y;
5145

52-
acc += a_vals * b;
53-
}
46+
const uint b_idx = src1_idx(i1, k, i2, i3);
47+
const float b = data_b[boffset + b_idx];
5448

55-
uint d_idx = dst_idx(i0, i1, i2, i3);
56-
for (uint q = 0; q < quant_group_sz; q++) {
57-
if (d_idx + q >= p.ne) {
58-
continue;
49+
acc += a_vals * b;
5950
}
6051

61-
data_d[doffset + d_idx + q] = acc[q];
52+
uint d_idx_0 = dst_idx(i0, i1, i2, i3);
53+
data_d[doffset + d_idx_0] = acc.x;
54+
uint d_idx_1 = dst_idx(i0 + 1, i1, i2, i3);
55+
data_d[doffset + d_idx_1] = acc.y;
6256
}
6357
}
6458
}

tests/test-backend-ops.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6169,6 +6169,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
61696169
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
61706170
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
61716171

6172+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 1, 32, {1, 1}, {1, 1}));
6173+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 2, 32, {1, 1}, {1, 1}));
6174+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 4, 32, {1, 1}, {1, 1}));
6175+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 8, 32, {1, 1}, {1, 1}));
6176+
6177+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 32, 1, 32, {1, 1}, {1, 1}));
6178+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 32, 2, 32, {1, 1}, {1, 1}));
6179+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 32, 4, 32, {1, 1}, {1, 1}));
6180+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 32, 8, 32, {1, 1}, {1, 1}));
6181+
6182+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q6_K, GGML_TYPE_F32, 32, 1, 32, {1, 1}, {1, 1}));
6183+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q6_K, GGML_TYPE_F32, 32, 2, 32, {1, 1}, {1, 1}));
6184+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q6_K, GGML_TYPE_F32, 32, 4, 32, {1, 1}, {1, 1}));
6185+
test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q6_K, GGML_TYPE_F32, 32, 8, 32, {1, 1}, {1, 1}));
6186+
6187+
6188+
61726189
#if 0 // these tests are disabled due to high memory usage and long runtime, they can fail on some backends
61736190
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4096*40, 256, 1024, { 1, 1}, {1, 1}));
61746191
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 4096*40, 256, 1024, { 1, 1}, {1, 1}));

0 commit comments

Comments
 (0)