Skip to content

Commit d4c5db4

Browse files
author
Italo Nicola
committed
(wip) Vulkan: Adreno Q4_0 fix
1 parent 7f4b20e commit d4c5db4

File tree

8 files changed

+97
-14
lines changed

8 files changed

+97
-14
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
3535
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
3636
}
3737
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
38-
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
39-
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
38+
const vec2 v01 = dequantize(ib, iqs, a_offset);
39+
const vec2 v23 = dequantize(ib, iqs + 1, a_offset);
40+
return vec4(v01.x, v01.y, v23.x, v23.y);
4041
}
4142
#endif
4243

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,33 @@ layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[
6262
#endif
6363

6464
#if defined(DATA_A_Q4_0)
65+
#define BINDING_IDX_K 0
66+
#define BINDING_IDX_V 1
67+
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE data[];} kv_packed[2];
6568
#define BLOCK_BYTE_SIZE 18
6669

6770
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
68-
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
69-
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
71+
uint v00 =
72+
uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 0]);
73+
uint v01 =
74+
uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 1]);
75+
uint v10 =
76+
uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 2]);
77+
uint v11 =
78+
uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 3]);
79+
7080
uint shift = (iqs & 0x10) >> 2;
71-
vui_lo >>= shift;
72-
vui_hi >>= shift;
81+
v00 >>= shift;
82+
v01 >>= shift;
83+
v10 >>= shift;
84+
v11 >>= shift;
85+
86+
v00 = v00 & 0xF;
87+
v01 = v01 & 0xF;
88+
v10 = v10 & 0xF;
89+
v11 = v11 & 0xF;
7390

74-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
91+
return float(kv_packed[binding_idx].data[a_offset + ib].d) * (vec4(v00, v01, v10, v11) - 8.0f);
7592
}
7693
#endif
7794

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,28 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
2424

2525
#if K_PER_ITER == 8
2626
#if QUANT_R == 2
27-
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
28-
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
29-
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
30-
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
27+
// Replicate the original data_b_v4 indexing with /4 rounding
28+
uint idx1 = (j*p.batch_stride_b + b_offset + iybs + iqs);
29+
uint idx2 = (j*p.batch_stride_b + b_offset + iybs + iqs + y_offset);
30+
uint base1 = (idx1 / 4) * 4; // Round down to nearest multiple of 4
31+
uint base2 = (idx2 / 4) * 4; // Round down to nearest multiple of 4
32+
33+
const FLOAT_TYPE bv02_x = FLOAT_TYPE(data_b[base1 + 0]);
34+
const FLOAT_TYPE bv02_y = FLOAT_TYPE(data_b[base1 + 1]);
35+
const FLOAT_TYPE bv02_z = FLOAT_TYPE(data_b[base1 + 2]);
36+
const FLOAT_TYPE bv02_w = FLOAT_TYPE(data_b[base1 + 3]);
37+
const FLOAT_TYPE bv13_x = FLOAT_TYPE(data_b[base2 + 0]);
38+
const FLOAT_TYPE bv13_y = FLOAT_TYPE(data_b[base2 + 1]);
39+
const FLOAT_TYPE bv13_z = FLOAT_TYPE(data_b[base2 + 2]);
40+
const FLOAT_TYPE bv13_w = FLOAT_TYPE(data_b[base2 + 3]);
41+
// XXX this is not guaranteed to be used for Q4, so make sure it works for everything else
42+
#if 1
43+
const vec4 bv0 = vec4(bv02_x, bv13_x, bv02_y, bv13_y);
44+
const vec4 bv1 = vec4(bv02_z, bv13_z, bv02_w, bv13_w);
45+
#else
46+
const vec4 bv0 = vec4(1.0, 1.0, 1.0, 1.0);
47+
const vec4 bv1 = vec4(1.0, 1.0, 1.0, 1.0);
48+
#endif
3149
#else
3250
const FLOAT_TYPE bv00 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) ]);
3351
const FLOAT_TYPE bv01 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 1]);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1212
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
13-
#if !defined(DATA_A_Q8_0)
13+
#if !defined(DATA_A_Q8_0) && !defined(DATA_A_Q4_0)
1414
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
1515
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
1616
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
128128
}
129129
#endif
130130

131+
#if defined(DATA_A_Q4_0)
132+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
133+
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
134+
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
135+
}
136+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
137+
const vec2 v01 = dequantize(ib, iqs, a_offset);
138+
const vec2 v23 = dequantize(ib, iqs + 1, a_offset);
139+
return vec4(v01.x, v01.y, v23.x, v23.y);
140+
}
141+
#endif
142+
131143
void main() {
132144
#ifdef NEEDS_INIT_IQ_SHMEM
133145
init_iq_shmem(gl_WorkGroupSize);
@@ -335,6 +347,7 @@ void main() {
335347
const uint ib = idx / 4;
336348
const uint iqs = idx & 0x03;
337349

350+
#if 0
338351
const float d = float(data_a_packed16[ib].d);
339352
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
340353
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
@@ -348,6 +361,20 @@ void main() {
348361
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
349362
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
350363
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
364+
#else
365+
const float d = float(data_a[ib].d);
366+
const vec4 vxy = dequantize4(ib, 4*iqs, 0) * d;
367+
const vec4 vzw = dequantize4(ib, 4*iqs + 2, 0) * d;
368+
369+
buf_a[buf_idx ] = FLOAT_TYPE(vxy.x);
370+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(vxy.z);
371+
buf_a[buf_idx + 2 ] = FLOAT_TYPE(vzw.x);
372+
buf_a[buf_idx + 3 ] = FLOAT_TYPE(vzw.z);
373+
buf_a[buf_idx + 16] = FLOAT_TYPE(vxy.y);
374+
buf_a[buf_idx + 17] = FLOAT_TYPE(vxy.w);
375+
buf_a[buf_idx + 18] = FLOAT_TYPE(vzw.y);
376+
buf_a[buf_idx + 19] = FLOAT_TYPE(vzw.w);
377+
#endif
351378
#elif defined(DATA_A_Q4_1)
352379
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
353380
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2626

27-
#if defined(DATA_A_Q8_0)
27+
#if defined(DATA_A_Q8_0) || defined(DATA_A_Q4_0)
2828
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
2929
#else
3030
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,32 @@
88

99
#if defined(DATA_A_Q4_0)
1010
i32vec2 repack(uint ib, uint iqs) {
11+
#if 0
1112
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
1213
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
1314
data_a[ib].qs[iqs * 2 + 1]);
1415
const uint32_t vui = pack32(quants);
1516
return i32vec2( vui & 0x0F0F0F0F,
1617
(vui >> 4) & 0x0F0F0F0F);
18+
#else
19+
int32_t u0 = int32_t(uint(data_a[ib].qs[iqs * 4]));
20+
int32_t u1 = int32_t(uint(data_a[ib].qs[iqs * 4 + 1]));
21+
int32_t u2 = int32_t(uint(data_a[ib].qs[iqs * 4 + 2]));
22+
int32_t u3 = int32_t(uint(data_a[ib].qs[iqs * 4 + 3]));
23+
24+
int32_t v0 = int32_t(
25+
(u0 & 0xF) |
26+
((u1 & 0xF) << 8) |
27+
((u2 & 0xF) << 16) |
28+
((u3 & 0xF) << 24));
29+
int32_t v1 = int32_t(
30+
((u0 >> 4) & 0xF) |
31+
(((u1 >> 4) & 0xF) << 8) |
32+
(((u2 >> 4) & 0xF) << 16) |
33+
(((u3 >> 4) & 0xF) << 24));
34+
35+
return i32vec2(v0, v1);
36+
#endif
1737
}
1838

1939
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct block_q4_0_packed16
6565
#define QUANT_R QUANT_R_Q4_0
6666
#define QUANT_AUXF 1
6767
#define A_TYPE block_q4_0
68-
#define A_TYPE_PACKED16 block_q4_0_packed16
68+
//#define A_TYPE_PACKED16 block_q4_0_packed16
6969
#endif
7070

7171
#define QUANT_K_Q4_1 32

0 commit comments

Comments
 (0)