Skip to content

Commit 825b09e

Browse files
author
Italo Nicola
committed
(wip) Vulkan: Adreno Q4_1 fix
1 parent d4c5db4 commit 825b09e

File tree

8 files changed

+97
-10
lines changed

8 files changed

+97
-10
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
4747
return vec2(vui & 0xF, vui >> 4);
4848
}
4949
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
50-
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
51-
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
50+
const vec2 v01 = dequantize(ib, iqs, a_offset);
51+
const vec2 v23 = dequantize(ib, iqs + 1, a_offset);
52+
return vec4(v01.x, v01.y, v23.x, v23.y);
5253
}
5354
#endif
5455

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,37 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
9292
}
9393
#endif
9494

95+
#if defined(DATA_A_Q4_1)
96+
#define BINDING_IDX_K 0
97+
#define BINDING_IDX_V 1
98+
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE data[];} kv_packed[2];
99+
#define BLOCK_BYTE_SIZE 20
100+
101+
// Dequantizes four Q4_1 values from the buffer selected by binding_idx
// (BINDING_IDX_K or BINDING_IDX_V).
//   ib          - block index; the block read is data[a_offset + ib]
//   iqs         - value index inside the block: its low 4 bits pick the first of
//                 four consecutive qs bytes, bit 4 picks the low (0) or high (1)
//                 nibble of each of those bytes
//   a_offset    - base block offset added to ib
//   binding_idx - index into kv_packed[]
// Returns d * q + m for each of the four nibbles. Q4_1 stores an explicit
// per-block offset m next to the scale d, so the fixed "- 8" bias used by the
// Q4_0 variant must not be applied here.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk  = a_offset + ib;
    const uint base = iqs & 0xF;
    // 0 selects the low nibble of each byte, 4 selects the high nibble.
    const uint shift = (iqs & 0x10) >> 2;

    // Byte-wise loads from the unpacked block layout (A_TYPE), one nibble per byte.
    const uint q0 = (uint(kv_packed[binding_idx].data[blk].qs[base + 0]) >> shift) & 0xF;
    const uint q1 = (uint(kv_packed[binding_idx].data[blk].qs[base + 1]) >> shift) & 0xF;
    const uint q2 = (uint(kv_packed[binding_idx].data[blk].qs[base + 2]) >> shift) & 0xF;
    const uint q3 = (uint(kv_packed[binding_idx].data[blk].qs[base + 3]) >> shift) & 0xF;

    const float d = float(kv_packed[binding_idx].data[blk].d);
    const float m = float(kv_packed[binding_idx].data[blk].m);

    return d * vec4(q0, q1, q2, q3) + m;
}
124+
#endif
125+
95126
#if defined(DATA_A_Q8_0)
96127
#define BINDING_IDX_K 0
97128
#define BINDING_IDX_V 1

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1212
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
13-
#if !defined(DATA_A_Q8_0) && !defined(DATA_A_Q4_0)
13+
#if !defined(DATA_A_Q8_0) && !defined(DATA_A_Q4_0) && !defined(DATA_A_Q4_1)
1414
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
1515
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
1616
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
140140
}
141141
#endif
142142

143+
#if defined(DATA_A_Q4_1)
144+
// Splits the value stored at qs[iqs] of block (a_offset + ib) into two numbers:
// x = its low 4 bits, y = the bits above them. No scale or offset is applied
// here; callers multiply by d and add m themselves.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint packed = uint(data_a[a_offset + ib].qs[iqs]);
    const uint lo = packed & 0xF;
    const uint hi = packed >> 4;
    return vec2(lo, hi);
}
148+
// Unpacks qs[iqs] and qs[iqs + 1] of block (a_offset + ib) with two calls to
// dequantize() and returns them as (low0, high0, low1, high1). Like
// dequantize(), the result is unscaled.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const vec2 first  = dequantize(ib, iqs,     a_offset);
    const vec2 second = dequantize(ib, iqs + 1, a_offset);
    return vec4(first, second);
}
153+
#endif
154+
143155
void main() {
144156
#ifdef NEEDS_INIT_IQ_SHMEM
145157
init_iq_shmem(gl_WorkGroupSize);
@@ -382,6 +394,7 @@ void main() {
382394
const uint ib = idx / 4;
383395
const uint iqs = idx & 0x03;
384396

397+
#if 0
385398
const float d = float(data_a_packed16[ib].d);
386399
const float m = float(data_a_packed16[ib].m);
387400
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
@@ -396,6 +409,21 @@ void main() {
396409
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
397410
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
398411
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
412+
#else
413+
const float d = float(data_a[ib].d);
414+
const float m = float(data_a[ib].m);
415+
const vec4 vxy = dequantize4(ib, 4*iqs, 0) * d + m;
416+
const vec4 vzw = dequantize4(ib, 4*iqs + 2, 0) * d + m;
417+
418+
buf_a[buf_idx ] = FLOAT_TYPE(vxy.x);
419+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(vxy.z);
420+
buf_a[buf_idx + 2 ] = FLOAT_TYPE(vzw.x);
421+
buf_a[buf_idx + 3 ] = FLOAT_TYPE(vzw.z);
422+
buf_a[buf_idx + 16] = FLOAT_TYPE(vxy.y);
423+
buf_a[buf_idx + 17] = FLOAT_TYPE(vxy.w);
424+
buf_a[buf_idx + 18] = FLOAT_TYPE(vzw.y);
425+
buf_a[buf_idx + 19] = FLOAT_TYPE(vzw.w);
426+
#endif
399427
#elif defined(DATA_A_Q5_0)
400428
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
401429
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2626

27-
#if defined(DATA_A_Q8_0) || defined(DATA_A_Q4_0)
27+
#if defined(DATA_A_Q8_0) || defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
2828
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
2929
#else
3030
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,30 @@ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
4343

4444
#if defined(DATA_A_Q4_1)
4545
// Repacks four consecutive qs entries of Q4_1 block `ib` (indices iqs * 4 .. iqs * 4 + 3)
// into two 32-bit integers. In the returned vector, .x carries the low 4 bits of each
// entry and .y carries bits 4..7 of each entry; entry k is placed at bit offset 8 * k.
// NOTE(review): the commit title calls this an Adreno fix, so the byte-wise loads in the
// active branch are presumably a driver workaround - confirm before simplifying.
i32vec2 repack(uint ib, uint iqs) {
46+
// Disabled path: one 32-bit load through the packed32 view of the block.
#if 0
4647
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
4748
const uint32_t vui = data_a_packed32[ib].qs[iqs];
4849
return i32vec2( vui & 0x0F0F0F0F,
4950
(vui >> 4) & 0x0F0F0F0F);
51+
// Active path: four separate loads from the unpacked block, widened to int32.
#else
52+
int32_t u0 = int32_t(uint(data_a[ib].qs[iqs * 4]));
53+
int32_t u1 = int32_t(uint(data_a[ib].qs[iqs * 4 + 1]));
54+
int32_t u2 = int32_t(uint(data_a[ib].qs[iqs * 4 + 2]));
55+
int32_t u3 = int32_t(uint(data_a[ib].qs[iqs * 4 + 3]));
56+
57+
// v0: low 4 bits of u0..u3 at bit offsets 0, 8, 16 and 24.
int32_t v0 = int32_t(
58+
(u0 & 0xF) |
59+
((u1 & 0xF) << 8) |
60+
((u2 & 0xF) << 16) |
61+
((u3 & 0xF) << 24));
62+
// v1: bits 4..7 of u0..u3 at bit offsets 0, 8, 16 and 24.
int32_t v1 = int32_t(
63+
((u0 >> 4) & 0xF) |
64+
(((u1 >> 4) & 0xF) << 8) |
65+
(((u2 >> 4) & 0xF) << 16) |
66+
(((u3 >> 4) & 0xF) << 24));
67+
68+
return i32vec2(v0, v1);
69+
#endif
5070
}
5171

5272
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
@@ -123,8 +143,15 @@ FLOAT_TYPE get_d(uint ib) {
123143
}
124144
#endif
125145

126-
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
146+
#if defined(DATA_A_Q4_1)
147+
// Returns (d, m) for Q4_1 block `ib`, built from the separate d and m fields
// of the unpacked data_a entry.
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib].d, data_a[ib].m);
    return dm;
}
150+
#endif
151+
152+
#if defined(DATA_A_Q5_1)
127153
// Returns the dm field of block `ib`, read through the packed32 view and
// converted to FLOAT_TYPE_VEC2.
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
    return dm;
}
130156
#endif
157+

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ struct block_q4_1_packed32
9696
#define QUANT_R QUANT_R_Q4_1
9797
#define QUANT_AUXF 2
9898
#define A_TYPE block_q4_1
99-
#define A_TYPE_PACKED16 block_q4_1_packed16
100-
#define A_TYPE_PACKED32 block_q4_1_packed32
99+
//#define A_TYPE_PACKED16 block_q4_1_packed16
100+
//#define A_TYPE_PACKED32 block_q4_1_packed32
101101
#endif
102102

103103
#define QUANT_K_Q5_0 32

tests/test-backend-ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4922,7 +4922,7 @@ struct test_falcon : public test_llm {
49224922
static const ggml_type all_types[] = {
49234923
GGML_TYPE_F32, GGML_TYPE_F16, // GGML_TYPE_BF16,
49244924
GGML_TYPE_Q4_0,
4925-
// GGML_TYPE_Q4_1,
4925+
GGML_TYPE_Q4_1,
49264926
// GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
49274927
GGML_TYPE_Q8_0,
49284928
// GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
@@ -4938,14 +4938,14 @@ static const ggml_type base_types[] = {
49384938
GGML_TYPE_F32, GGML_TYPE_F16,
49394939
GGML_TYPE_Q8_0, // for I8MM tests
49404940
GGML_TYPE_Q4_0,
4941-
// GGML_TYPE_Q4_1, // for I8MM tests
4941+
GGML_TYPE_Q4_1, // for I8MM tests
49424942
// GGML_TYPE_Q4_K,
49434943
// GGML_TYPE_IQ2_XXS
49444944
};
49454945

49464946
static const ggml_type other_types[] = {
49474947
GGML_TYPE_Q4_0,
4948-
// GGML_TYPE_Q4_1,
4948+
GGML_TYPE_Q4_1,
49494949
// GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
49504950
GGML_TYPE_Q8_0,
49514951
// GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,

0 commit comments

Comments
 (0)