44#include "generic_binary_head.comp"
55#include "dequant_funcs.comp"
66
7- const uint num_threads = 256;
7+ const uint quant_group_sz = 2; // number of consecutive elements each invocation handles: x-stride of idx and bound of the store loop in main()
8+ const uint num_threads = 512 / quant_group_sz; // keeps num_threads * quant_group_sz at 512, the factor applied to gl_GlobalInvocationID.y in main()
89layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
910
1011void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
@@ -17,38 +18,47 @@ void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uin
1718}
1819
1920void main() { // post-change: each invocation accumulates a vec2 and stores up to quant_group_sz values at d_idx + q
20- // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
21- const uint num_iter = 2;
22-
2321 const uint broadcast2 = uint(p.param2);
2422 const uint broadcast3 = p.param3;
2523
26- uint idx = get_idx();
24+ uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x * quant_group_sz; // 262144 == 512 * 512; x is scaled because one invocation covers quant_group_sz elements
25+
26+ uint aoffset = get_aoffset(); // the three offsets are read once here instead of at every use
27+ uint boffset = get_boffset();
28+ uint doffset = get_doffset();
2729
28- [[unroll]] for (uint it = 0; it < num_iter; ++it) {
29- if (idx < p.ne) {
30- uint i0, i1, i2, i3;
31- get_dst_indices(idx, i0, i1, i2, i3);
30+ if (idx < p.ne) {
31+ uint i0, i1, i2, i3;
32+ get_dst_indices(idx, i0, i1, i2, i3);
3233
33- float acc = 0.0f;
34+ vec2 acc = vec2(0.0); // one lane per value stored below as acc[q]
35+
36+ for (uint k = 0; k < p.ne01; k++) {
37+ if (i0 + 1 >= p.ne20) { // XXX: condition does not depend on k, so when it holds every iteration is skipped and acc stays zero
38+ continue;
39+ }
3440
35- for (uint k = 0; k < p.ne01; k += 1) {
36- const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
37- const uint ib = a_block_base + (i0 / QUANT_K) * p.nb00;
38- const uint iqs = (i0 % QUANT_K) / QUANT_R;
41+ const uint a_block_base = aoffset + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
42+ const uint ib = a_block_base + ((i0) / QUANT_K) * p.nb00;
43+ const uint iqs = ((i0) % QUANT_K) / QUANT_R;
3944
40- const vec2 v = dequantize(ib, iqs, 0);
41- const vec2 dm = get_dm(ib, 0);
42- const float a_val = v.x * dm.x + dm.y;
45+ const vec2 v = dequantize(ib, iqs, 0); // NOTE(review): assumes v.x / v.y belong to dst elements d_idx and d_idx + 1 - confirm against dequantize() for every quant type
46+ const vec2 dm = get_dm(ib, 0);
47+ const vec2 a_vals = v * dm.x + dm.y; // dm.x multiplies and dm.y is added to both components of v
48+
49+ const uint b_idx = src1_idx(i1, k, i2, i3);
50+ const float b = data_b[boffset + b_idx];
51+
52+ acc += a_vals * b; // both lanes are weighted by the same data_b value
53+ }
4354
44- const uint b_idx = src1_idx(i1, k, i2, i3);
45- const float b = data_b[get_boffset() + b_idx];
46- acc += a_val * b;
55+ uint d_idx = dst_idx(i0, i1, i2, i3);
56+ for (uint q = 0; q < quant_group_sz; q++) {
57+ if (d_idx + q + 0 >= p.ne) { // XXX: skips the store when d_idx + q is not below p.ne
58+ continue;
4759 }
4860
49- uint d_idx = dst_idx(i0, i1, i2, i3);
50- data_d[get_doffset() + d_idx] = acc;
61+ data_d[doffset + d_idx + q] = acc[q];
5162 }
52- idx += num_threads;
5363 }
5464}
0 commit comments