Skip to content

Commit 0aef6c8

Browse files
zoqmakaveli10
authored and committed
Vulkan: Add Q4_0/Q8_0 OUT_PROD Vulkan support
1 parent 2b0c835 commit 0aef6c8

File tree

4 files changed

+141
-7
lines changed

4 files changed

+141
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,10 @@ struct vk_device_struct {
473473
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
474474
vk_pipeline pipeline_argsort_f32;
475475
vk_pipeline pipeline_sum_rows_f32;
476-
vk_pipeline pipeline_out_prod_f32, pipeline_out_prod_f16_f32;
476+
vk_pipeline pipeline_out_prod_f32;
477+
vk_pipeline pipeline_out_prod_f16_f32;
478+
vk_pipeline pipeline_out_prod_q4_0;
479+
vk_pipeline pipeline_out_prod_q8_0;
477480
vk_pipeline pipeline_argmax_f32;
478481
vk_pipeline pipeline_count_equal_i32;
479482
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2937,6 +2940,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
29372940

29382941
// TODO: should we have device->subgroup_size here or 0?
29392942
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2943+
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2944+
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
29402945

29412946
// TODO: should we have device->subgroup_size here or 0?
29422947
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@@ -6753,8 +6758,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
67536758
return nullptr;
67546759
}
67556760
case GGML_OP_OUT_PROD:
6756-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6757-
return ctx->device->pipeline_out_prod_f32;
6761+
if (dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
6762+
if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32;
6763+
if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0;
6764+
if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0;
67586765
}
67596766
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
67606767
return ctx->device->pipeline_out_prod_f16_f32;
@@ -6931,7 +6938,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69316938
}
69326939
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
69336940
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
6934-
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
6941+
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || op == GGML_OP_OUT_PROD || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
69356942
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
69366943
GGML_ASSERT(dst->buffer != nullptr);
69376944
const uint64_t ne00 = src0->ne[0];
@@ -10622,9 +10629,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1062210629
case GGML_OP_GROUP_NORM:
1062310630
case GGML_OP_L2_NORM:
1062410631
return ggml_is_contiguous(op->src[0]);
10625-
case GGML_OP_OUT_PROD:
10626-
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10627-
op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
10632+
case GGML_OP_OUT_PROD: {
10633+
const ggml_type t0 = op->src[0]->type;
10634+
const ggml_type t1 = op->src[1]->type;
10635+
const ggml_type td = op->type;
10636+
if (td != GGML_TYPE_F32 || t1 != GGML_TYPE_F32) {
10637+
return false;
10638+
}
10639+
switch (t0) {
10640+
case GGML_TYPE_F32:
10641+
case GGML_TYPE_F16:
10642+
case GGML_TYPE_Q4_0:
10643+
case GGML_TYPE_Q8_0:
10644+
return true;
10645+
default:
10646+
return false;
10647+
}
10648+
}
1062810649
case GGML_OP_ADD:
1062910650
case GGML_OP_SUB:
1063010651
case GGML_OP_MUL:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_binary_head.comp"
5+
#include "dequant_funcs.comp"
6+
7+
const uint num_threads = 256;
8+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
9+
10+
// Decompose the flat destination index `idx` into four per-dimension indices,
// using p.ne20, p.ne21 and p.ne22 as the extents. i20 is the fastest-varying
// index and i23 the slowest.
// NOTE(review): fastdiv is defined outside this shader; presumably it is an
// unsigned integer division — confirm against its definition.
void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
11+
// Peel off the slowest dimension first, then work on the remainder.
i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20));
12+
const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
13+
i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20));
14+
const uint i22_offset = i22*p.ne21*p.ne20;
15+
i21 = (idx - i23_offset - i22_offset) / p.ne20;
16+
// Whatever is left after removing the three outer contributions is i20.
i20 = idx - i23_offset - i22_offset - i21*p.ne20;
17+
}
18+
19+
// Entry point of the out_prod shader variant that reads quantized A data.
// Each invocation handles up to num_iter destination elements, spaced
// num_threads apart. For every in-range element it sums a_val * b over
// k in [0, p.ne01) and stores the sum in data_d.
void main() {
20+
// num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
21+
const uint num_iter = 2;
22+
23+
// Divisors applied to i2 / i3 when addressing A, so that several destination
// indices along those dimensions map onto the same A slice.
const uint broadcast2 = uint(p.param2);
24+
const uint broadcast3 = p.param3;
25+
26+
uint idx = get_idx();
27+
28+
[[unroll]] for (uint it = 0; it < num_iter; ++it) {
29+
// Skip indices past the end of the destination (p.ne elements in total).
if (idx < p.ne) {
30+
uint i0, i1, i2, i3;
31+
get_dst_indices(idx, i0, i1, i2, i3);
32+
33+
float acc = 0.0f;
34+
35+
for (uint k = 0; k < p.ne01; k += 1) {
36+
// Start of row k of A for this (i2, i3), after applying the broadcast divisors.
// NOTE(review): ib is passed to dequantize/get_dm as a block index, so
// p.nb01..p.nb03 are presumably strides counted in blocks — confirm on the host side.
const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
37+
// Block that holds element i0 (QUANT_K elements per block).
const uint ib = a_block_base + (i0 / QUANT_K);
38+
// Position inside the block, folded into the first QUANT_K / QUANT_R slots.
const uint iqs = i0 % (QUANT_K / QUANT_R);
39+
// Which QUANT_K / QUANT_R sized part of the block i0 falls in.
// NOTE(review): `lower = 1 - upper` only makes sense if upper is 0 or 1,
// i.e. QUANT_R == 2 for this build — confirm.
const uint upper = (i0 % QUANT_K) / (QUANT_K / QUANT_R);
40+
const uint lower = 1 - upper;
41+
42+
// NOTE(review): dequantize and get_dm come from dequant_funcs.comp (not shown);
// presumably v holds two raw values for slot iqs and dm holds (scale, offset) — confirm.
const vec2 v = dequantize(ib, iqs, 0);
43+
const vec2 dm = get_dm(ib, 0);
44+
45+
// Select v.x when upper == 0 and v.y when upper == 1, then apply dm.x and dm.y.
const float a_val = (v.x * lower + v.y * upper) * dm.x + dm.y;
46+
47+
// B is read at (i1, k, i2, i3): the reduction index k sits in B's second position.
const uint b_idx = src1_idx(i1, k, i2, i3);
48+
const float b = data_b[get_boffset() + b_idx];
49+
acc += a_val * b;
50+
}
51+
52+
uint d_idx = dst_idx(i0, i1, i2, i3);
53+
data_d[get_doffset() + d_idx] = acc;
54+
}
55+
// Move on to the next element assigned to this invocation.
idx += num_threads;
56+
}
57+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_binary_head.comp"
5+
#include "dequant_funcs.comp"
6+
7+
const uint num_threads = 256;
8+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
9+
10+
// Decompose the flat destination index `idx` into four per-dimension indices,
// using p.ne20, p.ne21 and p.ne22 as the extents. i20 is the fastest-varying
// index and i23 the slowest.
// NOTE(review): fastdiv is defined outside this shader; presumably it is an
// unsigned integer division — confirm against its definition.
void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
11+
// Peel off the slowest dimension first, then work on the remainder.
i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20));
12+
const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
13+
i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20));
14+
const uint i22_offset = i22*p.ne21*p.ne20;
15+
i21 = (idx - i23_offset - i22_offset) / p.ne20;
16+
// Whatever is left after removing the three outer contributions is i20.
i20 = idx - i23_offset - i22_offset - i21*p.ne20;
17+
}
18+
19+
// Entry point of the out_prod shader variant that reads quantized A data.
// Each invocation handles up to num_iter destination elements, spaced
// num_threads apart. For every in-range element it sums a_val * b over
// k in [0, p.ne01) and stores the sum in data_d.
void main() {
20+
// num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
21+
const uint num_iter = 2;
22+
23+
// Divisors applied to i2 / i3 when addressing A, so that several destination
// indices along those dimensions map onto the same A slice.
const uint broadcast2 = uint(p.param2);
24+
const uint broadcast3 = p.param3;
25+
26+
uint idx = get_idx();
27+
28+
[[unroll]] for (uint it = 0; it < num_iter; ++it) {
29+
// Skip indices past the end of the destination (p.ne elements in total).
if (idx < p.ne) {
30+
uint i0, i1, i2, i3;
31+
get_dst_indices(idx, i0, i1, i2, i3);
32+
33+
float acc = 0.0f;
34+
35+
for (uint k = 0; k < p.ne01; k += 1) {
36+
// Start of row k of A for this (i2, i3), after applying the broadcast divisors.
// NOTE(review): ib is passed to dequantize/get_dm as a block index, so
// p.nb01..p.nb03 are presumably strides counted in blocks — confirm on the host side.
const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
37+
// Block that holds element i0 (QUANT_K elements per block).
const uint ib = a_block_base + (i0 / QUANT_K);
38+
// Position of element i0 inside its block, scaled down by QUANT_R.
const uint iqs = (i0 % QUANT_K) / QUANT_R;
39+
40+
// NOTE(review): dequantize and get_dm come from dequant_funcs.comp (not shown);
// presumably v.x is the raw value at slot iqs and dm holds (scale, offset) — confirm.
const vec2 v = dequantize(ib, iqs, 0);
41+
const vec2 dm = get_dm(ib, 0);
42+
// Only the first component of v is used here; v.y is ignored.
const float a_val = v.x * dm.x + dm.y;
43+
44+
// B is read at (i1, k, i2, i3): the reduction index k sits in B's second position.
const uint b_idx = src1_idx(i1, k, i2, i3);
45+
const float b = data_b[get_boffset() + b_idx];
46+
acc += a_val * b;
47+
}
48+
49+
uint d_idx = dst_idx(i0, i1, i2, i3);
50+
data_d[get_doffset() + d_idx] = acc;
51+
}
52+
// Move on to the next element assigned to this invocation.
idx += num_threads;
53+
}
54+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,8 @@ void process_shaders() {
635635

636636
string_to_spv("out_prod_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
637637
string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
638+
string_to_spv("out_prod_q4_0", "out_prod_q4_0.comp", merge_maps(base_dict, {{"DATA_A_Q4_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
639+
string_to_spv("out_prod_q8_0", "out_prod_q8_0.comp", merge_maps(base_dict, {{"DATA_A_Q8_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
638640

639641
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
640642

0 commit comments

Comments
 (0)