|
| 1 | +#version 450 |
| 2 | +#extension GL_EXT_shader_explicit_arithmetic_types : require |
| 3 | +#extension GL_EXT_integer_dot_product : require |
| 4 | + |
| 5 | +#define MMQ |
| 6 | +#define B_TYPE block_q8_1_x4 |
| 7 | + |
| 8 | +#include "mul_mat_vec_base.comp" |
| 9 | +#include "mul_mmq_funcs.comp" |
| 10 | + |
| 11 | +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; |
| 12 | + |
| 13 | +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |
| 14 | + |
// Accumulates, for every column of B (NUM_COLS) and for `num_rows` rows of A
// starting at `first_row`, this invocation's partial dot product of 2-bit
// fields unpacked from the A blocks with packed 8-bit values from the B
// blocks, then passes the per-invocation partials to reduce_result().
//
//   first_row  index of the first A row handled by this workgroup
//   num_rows   number of rows to process; callers pass at most NUM_ROWS,
//              which is also the extent of `temp`
//
// NOTE(review): only `.d` and `.qs` of each A block are read here. If A_TYPE
// carries additional per-sub-block scales/mins they are not applied - confirm
// that this is intended.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint num_blocks_per_row = p.ncols / QUANT_K;

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

    const uint tid = gl_LocalInvocationID.x;

    // Eight consecutive invocations work on the same A block. Each one owns
    // four consecutive qs bytes of A (offset k) and the matching packed
    // 32-bit word of B (vec_idx). These values do not change inside the
    // loops below, so they are computed once.
    const uint lane        = tid % 8;
    const uint k           = lane * 4;
    const uint vec_idx     = lane;
    const uint first_block = tid / 8;
    const uint block_step  = gl_WorkGroupSize.x / 8;

    for (uint jcol = 0; jcol < NUM_COLS; jcol++) {
        const uint b_base = jcol * p.batch_stride_b;
        for (uint n = 0; n < num_rows; ++n) {
            const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
            FLOAT_TYPE acc = FLOAT_TYPE(0);
            for (uint i = first_block; i < num_blocks_per_row; i += block_step) {
                const uint ib = ib0 + i;
                const float d = float(data_a[ib].d);
                [[unroll]] for (uint j = 0; j < 64; j += 32) {
                    // Load this invocation's four qs bytes once and keep them
                    // packed, one source byte per byte lane.
                    const uint qs_bits = uint(data_a[ib].qs[j + k])
                                       | (uint(data_a[ib].qs[j + k + 1]) << 8)
                                       | (uint(data_a[ib].qs[j + k + 2]) << 16)
                                       | (uint(data_a[ib].qs[j + k + 3]) << 24);

                    [[unroll]] for (uint l = 0; l < 4; l += 2) {
                        // Shifting by 0/2/4/6 and masking each byte lane with
                        // 0x03 yields the same four 2-bit values as extracting
                        // them byte by byte: the selected bit pair never
                        // crosses a byte boundary.
                        const int32_t a0_packed = int32_t((qs_bits >> (l * 2u)) & 0x03030303u);
                        const int32_t a1_packed = int32_t((qs_bits >> ((l + 1) * 2u)) & 0x03030303u);

                        // Element offsets in B of the two 32-value groups that
                        // these bit planes are multiplied with.
                        const uint b0_idx = i * QUANT_K + j * 4 + l * 32;
                        const uint b1_idx = b0_idx + 32;

                        // Not checking for OOB since we're guaranteed to be multiple of 256
                        const uint b0_block_idx = b_offset + (b_base + b0_idx) / QUANT_K_Q8_1;
                        const uint b1_block_idx = b_offset + (b_base + b1_idx) / QUANT_K_Q8_1;

                        // B is stored four blocks per element (B_TYPE is
                        // block_q8_1_x4): split into element index and slot.
                        const uint b0_block_idx_outer = b0_block_idx / 4;
                        const uint b1_block_idx_outer = b1_block_idx / 4;
                        const uint b0_block_idx_inner = b0_block_idx % 4;
                        const uint b1_block_idx_inner = b1_block_idx % 4;

                        // b_offset is already part of b*_block_idx, so the
                        // outer index is used as-is; adding b_offset a second
                        // time would apply the batch offset twice.
                        // NOTE(review): assumes get_offsets() returns b_offset
                        // in units of QUANT_K_Q8_1-sized blocks - confirm.
                        const vec2 ds0 = vec2(data_b[b0_block_idx_outer].ds[b0_block_idx_inner]);
                        const vec2 ds1 = vec2(data_b[b1_block_idx_outer].ds[b1_block_idx_inner]);

                        const int32_t b0_packed = data_b[b0_block_idx_outer].qs[b0_block_idx_inner * 8 + vec_idx];
                        const int32_t b1_packed = data_b[b1_block_idx_outer].qs[b1_block_idx_inner * 8 + vec_idx];

                        const int32_t q0_sum = dotPacked4x8EXT(a0_packed, b0_packed);
                        const int32_t q1_sum = dotPacked4x8EXT(a1_packed, b1_packed);

                        // Each of the eight invocations sharing a B block
                        // subtracts 1/8 of ds.y, so summed over those eight
                        // the ds.y term is subtracted exactly once.
                        acc += ACC_TYPE(d * (FLOAT_TYPE(q0_sum) * ds0.x - FLOAT_TYPE(1.0f / 8) * ds0.y));
                        acc += ACC_TYPE(d * (FLOAT_TYPE(q1_sum) * ds1.x - FLOAT_TYPE(1.0f / 8) * ds1.y));
                    }
                }
            }
            temp[jcol][n] = acc;
        }
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
| 83 | + |
// Entry point: each workgroup handles one tile of NUM_ROWS consecutive rows.
// The tile index is derived from the workgroup's x and z coordinates.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // Full tile: pass NUM_ROWS itself so the row count stays a compile-time
    // constant on this path.
    const bool full_tile = (first_row + NUM_ROWS <= p.stride_d);
    if (full_tile) {
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Partial tile: process only the rows that remain; a tile that starts at
    // or beyond p.stride_d has nothing to do.
    if (first_row < p.stride_d) {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}
0 commit comments