Skip to content

Commit a05fcae

Browse files
author
Italo Nicola
committed
Vulkan: Improve Q8 OUT_PROD performance
Increase OUT_PROD Q8 performance through improving memory locality.
1 parent 29bb438 commit a05fcae

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include "generic_binary_head.comp"
55
#include "dequant_funcs.comp"
66

7-
const uint num_threads = 256;
7+
const uint quant_group_sz = 2;
8+
const uint num_threads = 512 / quant_group_sz;
89
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
910

1011
void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
@@ -17,38 +18,47 @@ void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uin
1718
}
1819

1920
void main() {
20-
// num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
21-
const uint num_iter = 2;
22-
2321
const uint broadcast2 = uint(p.param2);
2422
const uint broadcast3 = p.param3;
2523

26-
uint idx = get_idx();
24+
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x * quant_group_sz;
25+
26+
uint aoffset = get_aoffset();
27+
uint boffset = get_boffset();
28+
uint doffset = get_doffset();
2729

28-
[[unroll]] for (uint it = 0; it < num_iter; ++it) {
29-
if (idx < p.ne) {
30-
uint i0, i1, i2, i3;
31-
get_dst_indices(idx, i0, i1, i2, i3);
30+
if (idx < p.ne) {
31+
uint i0, i1, i2, i3;
32+
get_dst_indices(idx, i0, i1, i2, i3);
3233

33-
float acc = 0.0f;
34+
vec2 acc = vec2(0.0);
35+
36+
for (uint k = 0; k < p.ne01; k++) {
37+
if (i0 + 1 >= p.ne20) { // XXX
38+
continue;
39+
}
3440

35-
for (uint k = 0; k < p.ne01; k += 1) {
36-
const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
37-
const uint ib = a_block_base + (i0 / QUANT_K) * p.nb00;
38-
const uint iqs = (i0 % QUANT_K) / QUANT_R;
41+
const uint a_block_base = aoffset + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
42+
const uint ib = a_block_base + ((i0) / QUANT_K) * p.nb00;
43+
const uint iqs = ((i0) % QUANT_K) / QUANT_R;
3944

40-
const vec2 v = dequantize(ib, iqs, 0);
41-
const vec2 dm = get_dm(ib, 0);
42-
const float a_val = v.x * dm.x + dm.y;
45+
const vec2 v = dequantize(ib, iqs, 0);
46+
const vec2 dm = get_dm(ib, 0);
47+
const vec2 a_vals = v * dm.x + dm.y;
48+
49+
const uint b_idx = src1_idx(i1, k, i2, i3);
50+
const float b = data_b[boffset + b_idx];
51+
52+
acc += a_vals * b;
53+
}
4354

44-
const uint b_idx = src1_idx(i1, k, i2, i3);
45-
const float b = data_b[get_boffset() + b_idx];
46-
acc += a_val * b;
55+
uint d_idx = dst_idx(i0, i1, i2, i3);
56+
for (uint q = 0; q < quant_group_sz; q++) {
57+
if (d_idx + q + 0 >= p.ne) { // XXX
58+
continue;
4759
}
4860

49-
uint d_idx = dst_idx(i0, i1, i2, i3);
50-
data_d[get_doffset() + d_idx] = acc;
61+
data_d[doffset + d_idx + q] = acc[q];
5162
}
52-
idx += num_threads;
5363
}
5464
}

0 commit comments

Comments
 (0)