Skip to content

Commit 417679c

Browse files
author
Italo Nicola
committed
Vulkan: TQ2_0 x Q8_1 MUL_MAT perf improvements
1 parent 9271a97 commit 417679c

File tree

10 files changed

+228
-24
lines changed

10 files changed

+228
-24
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3148,6 +3148,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
31483148
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
31493149
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
31503150

3151+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_q8_1_f32", arr_dmmv_tq2_0_q8_1_f32_len[reduc], arr_dmmv_tq2_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
31513152
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
31523153
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
31533154
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
@@ -4829,6 +4830,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
48294830

48304831
if (b_type == GGML_TYPE_Q8_1) {
48314832
switch (a_type) {
4833+
case GGML_TYPE_TQ2_0:
48324834
case GGML_TYPE_Q4_0:
48334835
case GGML_TYPE_Q4_1:
48344836
case GGML_TYPE_Q5_0:
@@ -4891,6 +4893,9 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
48914893
if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
48924894
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
48934895
}
4896+
if (ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
4897+
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
4898+
}
48944899
return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
48954900
}
48964901

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,8 +664,12 @@ float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords
664664
const float16_t d = bl.block.d;
665665
const uint idx = coordInBlock[1];
666666

667-
const uint byte_idx = ((idx >> 7) << 5) + (idx & 31u);
668-
const uint qsshift = (((idx & 127u) >> 5) << 1);
667+
const uint iqs = idx % 128u;
668+
const uint upper = idx / 128u;
669+
670+
const uint byte_idx = (upper * 32u) + (iqs % 32u);
671+
672+
const uint qsshift = (iqs / 32u) * 2u;
669673

670674
const uint c = (uint(bl.block.qs[byte_idx]) >> qsshift) & 3u;
671675
return d * float16_t(float(c) - 1.0f);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
#include "types.comp"
1515

16-
#ifndef MMQ
16+
#if !defined(MMQ) || !defined(A_TYPE_PACKED16)
1717
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1818
#else
1919
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2121
}
2222
}
2323

24-
[[unroll]] for (uint i = tid; i < num_blocks_per_row; i += gl_WorkGroupSize.x) {
25-
26-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
27-
const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
28-
const float d = float(data_a[ib0 + i].d);
29-
30-
[[unroll]] for (uint j = 0; j < 64; j += 32) {
31-
[[unroll]] for (uint l = 0; l < 4; ++l) {
32-
[[unroll]] for (uint k = 0; k < 32; ++k) {
24+
for (uint n = 0; n < num_rows; ++n) {
25+
const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
26+
for (uint jcol = 0; jcol < NUM_COLS; ++jcol) {
27+
const uint b_base = (jcol * p.batch_stride_b);
28+
FLOAT_TYPE acc = 0.0f;
29+
for (uint i = tid/8; i < num_blocks_per_row; i += gl_WorkGroupSize.x/8) {
30+
const FLOAT_TYPE d = float(data_a[ib0 + i].d);
31+
32+
[[unroll]] for (uint j = 0; j < 64; j += 32) {
33+
[[unroll]] for (uint l = 0; l < 4; ++l) {
34+
[[unroll]] for (uint k = tid%8; k < 32; k+=8) {
35+
//uint k = (tid % 8) * 4;
3336
// Extract quantized value: ((x[i].qs[j + k] >> (l*2)) & 3) - 1
3437
const uint q_byte = uint(data_a[ib0 + i].qs[j + k]);
3538
const uint shift = l * 2;
@@ -38,10 +41,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
3841

3942
// y-data access pattern: y[i].qs[j*4 + l*32 + k]
4043
const uint b_idx = i * QUANT_K + j * 4 + l * 32 + k;
41-
if (b_idx < p.ncols) {
42-
[[unroll]] for (uint jcol = 0; jcol < NUM_COLS; ++jcol) {
43-
temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[jcol * p.batch_stride_b + b_offset + b_idx]);
44-
}
44+
temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[b_base + b_offset + b_idx]);
4545
}
4646
}
4747
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#version 450
2+
#extension GL_EXT_shader_explicit_arithmetic_types : require
3+
#extension GL_EXT_integer_dot_product : require
4+
5+
#define MMQ
6+
#define B_TYPE block_q8_1_x4
7+
8+
#include "mul_mat_vec_base.comp"
9+
#include "mul_mmq_funcs.comp"
10+
11+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
12+
13+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
14+
15+
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
16+
uint a_offset, b_offset, d_offset;
17+
get_offsets(a_offset, b_offset, d_offset);
18+
19+
const uint num_blocks_per_row = p.ncols / QUANT_K;
20+
21+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
22+
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
23+
temp[j][i] = FLOAT_TYPE(0);
24+
}
25+
}
26+
27+
const uint tid = gl_LocalInvocationID.x;
28+
29+
for (uint jcol = 0; jcol < NUM_COLS; jcol++) {
30+
const uint b_base = (jcol * p.batch_stride_b);
31+
for (uint n = 0; n < num_rows; ++n) {
32+
const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
33+
FLOAT_TYPE acc = 0.0f;
34+
for (uint i = tid/8; i < num_blocks_per_row; i+=gl_WorkGroupSize.x/8) {
35+
const float d = float(data_a[ib0 + i].d);
36+
[[unroll]] for (uint j = 0; j < 64; j += 32) {
37+
[[unroll]] for (uint l = 0; l < 4; l+=2) {
38+
uint k = (tid % 8) * 4;
39+
40+
const uint shift0 = l * 2u;
41+
const int c00 = int(((data_a[ib0 + i].qs[j + k]) >> shift0) & 3u);
42+
const int c01 = int(((data_a[ib0 + i].qs[j + k + 1]) >> shift0) & 3u);
43+
const int c02 = int(((data_a[ib0 + i].qs[j + k + 2]) >> shift0) & 3u);
44+
const int c03 = int(((data_a[ib0 + i].qs[j + k + 3]) >> shift0) & 3u);
45+
const int32_t a0_packed = c00 | (c01 << 8) | (c02 << 16) | (c03 << 24);
46+
const uint b0_idx = i * QUANT_K + j * 4 + l * 32;
47+
48+
const uint shift1 = (l+1) * 2u;
49+
const int c10 = int(((data_a[ib0 + i].qs[j + k]) >> shift1) & 3u);
50+
const int c11 = int(((data_a[ib0 + i].qs[j + k + 1]) >> shift1) & 3u);
51+
const int c12 = int(((data_a[ib0 + i].qs[j + k + 2]) >> shift1) & 3u);
52+
const int c13 = int(((data_a[ib0 + i].qs[j + k + 3]) >> shift1) & 3u);
53+
const int32_t a1_packed = c10 | (c11 << 8) | (c12 << 16) | (c13 << 24);
54+
const uint b1_idx = i * QUANT_K + j * 4 + (l+1) * 32;
55+
56+
// Not checking for OOB since we're guaranteed to be multiple of 256
57+
const uint b0_block_idx = b_offset + (b_base + b0_idx) / QUANT_K_Q8_1;
58+
const uint b1_block_idx = b_offset + (b_base + b1_idx) / QUANT_K_Q8_1;
59+
const uint b0_block_idx_outer = b0_block_idx / 4;
60+
const uint b1_block_idx_outer = b1_block_idx / 4;
61+
const uint b0_block_idx_inner = b0_block_idx % 4;
62+
const uint b1_block_idx_inner = b1_block_idx % 4;
63+
vec2 ds0 = vec2(data_b[b_offset + b0_block_idx_outer].ds[b0_block_idx_inner]);
64+
vec2 ds1 = vec2(data_b[b_offset + b1_block_idx_outer].ds[b1_block_idx_inner]);
65+
66+
const uint vec_idx = k / 4;
67+
int32_t b0_packed = data_b[b_offset + b0_block_idx_outer].qs[b0_block_idx_inner * 8 + vec_idx];
68+
int32_t b1_packed = data_b[b_offset + b1_block_idx_outer].qs[b1_block_idx_inner * 8 + vec_idx];
69+
70+
int32_t q0_sum = dotPacked4x8EXT(a0_packed, b0_packed);
71+
int32_t q1_sum = dotPacked4x8EXT(a1_packed, b1_packed);
72+
acc += ACC_TYPE(d * (FLOAT_TYPE(q0_sum) * ds0.x - FLOAT_TYPE(1.0f / 8) * ds0.y));
73+
acc += ACC_TYPE(d * (FLOAT_TYPE(q1_sum) * ds1.x - FLOAT_TYPE(1.0f / 8) * ds1.y));
74+
}
75+
}
76+
}
77+
temp[jcol][n] = acc;
78+
}
79+
}
80+
81+
reduce_result(temp, d_offset, first_row, num_rows, tid);
82+
}
83+
84+
void main() {
85+
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
86+
87+
if (first_row + NUM_ROWS <= p.stride_d) {
88+
compute_outputs(first_row, NUM_ROWS);
89+
} else {
90+
if (first_row >= p.stride_d) {
91+
return;
92+
}
93+
compute_outputs(first_row, p.stride_d - first_row);
94+
}
95+
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
3030
const uint b_block_idx_inner = b_block_idx % 4;
3131
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
3232

33-
#if QUANT_R == 2
33+
#if QUANT_R == 2 || QUANT_R == 4
3434
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
3535
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
3636
#else
@@ -40,12 +40,19 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
4040

4141
uint ibi = first_row*p.ncols;
4242
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
43-
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
44-
ibi += p.ncols;
43+
const uint cur_ibi = ibi;
44+
const uint a_block_idx = (cur_ibi + col)/QUANT_K + a_offset;
4545

4646
int32_t q_sum = 0;
47-
#if QUANT_R == 2
47+
#if QUANT_R == 2 || QUANT_R == 4
48+
#if defined(DATA_A_TQ2_0)
49+
// For TQ2_0 (QUANT_K=256), repack needs the within-block K base to select
50+
// the correct half (k<128 or k>=128) and 32-wide quarter. Pass k_base+b_qs_idx.
51+
const uint k_base = (cur_ibi + col) % QUANT_K;
52+
const i32vec2 data_a_qs = repack(a_block_idx, k_base + b_qs_idx);
53+
#else
4854
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
55+
#endif
4956
q_sum += dotPacked4x8EXT(data_a_qs.x,
5057
cache_b_qs[0]);
5158
q_sum += dotPacked4x8EXT(data_a_qs.y,
@@ -59,11 +66,14 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
5966
cache_b_qs[1]);
6067
#endif
6168

62-
#if QUANT_AUXF == 1
69+
#if QUANT_AUXF == 1 && QUANT_R <= 2
70+
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
71+
#elif QUANT_AUXF == 1 && QUANT_R == 4
6372
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
6473
#else
6574
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
6675
#endif
76+
ibi += p.ncols;
6777
}
6878
}
6979
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,82 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int
8686
}
8787
#endif
8888

89-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
89+
#if defined(DATA_A_TQ2_0)
90+
i32vec2 repack(uint ib, uint iqs) {
91+
const uint k00 = iqs + 0u;
92+
const uint ip00 = ((k00 >> 7) & 1u) * 32u;
93+
const uint b00 = (k00 & 31u) + ip00;
94+
const uint s00 = ((k00 >> 5) & 3u) * 2u;
95+
96+
const uint k01 = iqs + 1u;
97+
const uint ip01 = ((k01 >> 7) & 1u) * 32u;
98+
const uint b01 = (k01 & 31u) + ip01;
99+
const uint s01 = ((k01 >> 5) & 3u) * 2u;
100+
101+
const uint k02 = iqs + 2u;
102+
const uint ip02 = ((k02 >> 7) & 1u) * 32u;
103+
const uint b02 = (k02 & 31u) + ip02;
104+
const uint s02 = ((k02 >> 5) & 3u) * 2u;
105+
106+
const uint k03 = iqs + 3u;
107+
const uint ip03 = ((k03 >> 7) & 1u) * 32u;
108+
const uint b03 = (k03 & 31u) + ip03;
109+
const uint s03 = ((k03 >> 5) & 3u) * 2u;
110+
111+
const int q00 = int(data_a[ib].qs[b00]);
112+
const int q01 = int(data_a[ib].qs[b01]);
113+
const int q02 = int(data_a[ib].qs[b02]);
114+
const int q03 = int(data_a[ib].qs[b03]);
115+
116+
const int t00 = (q00 >> int(s00)) & 3;
117+
const int t01 = (q01 >> int(s01)) & 3;
118+
const int t02 = (q02 >> int(s02)) & 3;
119+
const int t03 = (q03 >> int(s03)) & 3;
120+
121+
const int v0 = (t00 & 0xFF) | ((t01 & 0xFF) << 8) | ((t02 & 0xFF) << 16) | ((t03 & 0xFF) << 24);
122+
123+
124+
const uint k10 = iqs + 16u + 0u;
125+
const uint ip10 = ((k10 >> 7) & 1u) * 32u;
126+
const uint b10 = (k10 & 31u) + ip10;
127+
const uint s10 = ((k10 >> 5) & 3u) * 2u;
128+
129+
const uint k11 = iqs + 16u + 1u;
130+
const uint ip11 = ((k11 >> 7) & 1u) * 32u;
131+
const uint b11 = (k11 & 31u) + ip11;
132+
const uint s11 = ((k11 >> 5) & 3u) * 2u;
133+
134+
const uint k12 = iqs + 16u + 2u;
135+
const uint ip12 = ((k12 >> 7) & 1u) * 32u;
136+
const uint b12 = (k12 & 31u) + ip12;
137+
const uint s12 = ((k12 >> 5) & 3u) * 2u;
138+
139+
const uint k13 = iqs + 16u + 3u;
140+
const uint ip13 = ((k13 >> 7) & 1u) * 32u;
141+
const uint b13 = (k13 & 31u) + ip13;
142+
const uint s13 = ((k13 >> 5) & 3u) * 2u;
143+
144+
const int q10 = int(data_a[ib].qs[b10]);
145+
const int q11 = int(data_a[ib].qs[b11]);
146+
const int q12 = int(data_a[ib].qs[b12]);
147+
const int q13 = int(data_a[ib].qs[b13]);
148+
149+
const int u10 = (q10 >> int(s10)) & 3;
150+
const int u11 = (q11 >> int(s11)) & 3;
151+
const int u12 = (q12 >> int(s12)) & 3;
152+
const int u13 = (q13 >> int(s13)) & 3;
153+
154+
const int v1 = (u10 & 0xFF) | ((u11 & 0xFF) << 8) | ((u12 & 0xFF) << 16) | ((u13 & 0xFF) << 24);
155+
156+
return i32vec2(v0, v1);
157+
}
158+
159+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
160+
return ACC_TYPE(da * (float(q_sum) * dsb.x - float(1.0f / sum_divisor) * dsb.y));
161+
}
162+
#endif
163+
164+
#if defined(DATA_A_TQ2_0) || defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
90165
FLOAT_TYPE get_d(uint ib) {
91166
return FLOAT_TYPE(data_a[ib].d);
92167
}

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,7 @@ struct block_tq2_0
13681368
#if defined(DATA_A_TQ2_0)
13691369
#define QUANT_K QUANT_K_TQ2_0
13701370
#define QUANT_R QUANT_R_TQ2_0
1371+
#define QUANT_AUXF 1
13711372
#define A_TYPE block_tq2_0
13721373
#endif
13731374

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,10 @@ void process_shaders() {
532532
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
533533
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
534534
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
535+
} else if (tname == "tq2_0") {
536+
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vec_tq2_0_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
537+
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vec_tq2_0_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
538+
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vec_tq2_0_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
535539
}
536540
#endif
537541

@@ -992,7 +996,7 @@ void write_output_files() {
992996

993997
for (const std::string& btype : btypes) {
994998
for (const auto& tname : type_names) {
995-
if (btype == "q8_1" && !is_legacy_quant(tname)) {
999+
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0") {
9961000
continue;
9971001
}
9981002
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());

tests/test-backend-ops.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6181,13 +6181,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
61816181
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32, 1024, 256, 4096*40, { 1, 1}, {1, 1}));
61826182
#endif
61836183

6184+
//test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 1, 16, {1, 1}, {1, 1}));
6185+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 1, 256, {1, 1}, {1, 1}));
6186+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 2, 256, {1, 1}, {1, 1}));
6187+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 4, 256, {1, 1}, {1, 1}));
6188+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 8, 256, {1, 1}, {1, 1}));
6189+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 32, 32, 256, {1, 1}, {1, 1}));
6190+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 1, 1024, {1, 1}, {1, 1}));
6191+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 2, 1024, {1, 1}, {1, 1}));
6192+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_TQ2_0, GGML_TYPE_F32, 16, 4, 1024, {1, 1}, {1, 1}));
6193+
6194+
#if 0
61846195
for (ggml_type type_a : all_types) {
61856196
for (int i = 1; i < 10; ++i) {
61866197
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
61876198
}
61886199
}
61896200

6190-
#if 1
61916201
for (ggml_type type_a : base_types) {
61926202
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
61936203
std::vector<int> ks = { 256 };

0 commit comments

Comments
 (0)