Skip to content

Commit 911d0d9

Browse files
committed
Add Vulkan TQ2_0 shader.
Signed-off-by: Marcus Edel <[email protected]>
1 parent aafd00f commit 911d0d9

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ struct vk_device_struct {
557557
vk_pipeline pipeline_out_prod_f16_f32;
558558
vk_pipeline pipeline_out_prod_q4_0;
559559
vk_pipeline pipeline_out_prod_q8_0;
560+
vk_pipeline pipeline_out_prod_tq2_0;
560561
vk_pipeline pipeline_argmax_f32;
561562
vk_pipeline pipeline_count_equal_i32;
562563
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3432,6 +3433,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
34323433
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
34333434
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
34343435
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
3436+
ggml_vk_create_pipeline(device, device->pipeline_out_prod_tq2_0, "out_prod_tq2_0", out_prod_tq2_0_len, out_prod_tq2_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
34353437

34363438
// TODO: should we have device->subgroup_size here or 0?
34373439
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
@@ -7805,6 +7807,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
78057807
if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32;
78067808
if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0;
78077809
if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0;
7810+
if (src0->type == GGML_TYPE_TQ2_0) return ctx->device->pipeline_out_prod_tq2_0;
78087811
}
78097812
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
78107813
return ctx->device->pipeline_out_prod_f16_f32;
@@ -12558,6 +12561,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1255812561
case GGML_TYPE_F16:
1255912562
case GGML_TYPE_Q4_0:
1256012563
case GGML_TYPE_Q8_0:
12564+
case GGML_TYPE_TQ2_0:
1256112565
return true;
1256212566
default:
1256312567
return false;
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_binary_head.comp"
5+
#include "dequant_funcs.comp"
6+
7+
// Number of invocations per workgroup along x. main() walks num_iter elements
// per invocation, and num_threads * num_iter must equal 512 (see main()).
const uint num_threads = 256;
8+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
9+
10+
// Split a flat destination element index into its four coordinates.
//
// The index is unpacked using the extents p.ne20, p.ne21 and p.ne22 from the
// push constants, with dimension 0 varying fastest.
//
//   idx  flat element index to decompose
//   i20  out: coordinate along dimension 0
//   i21  out: coordinate along dimension 1
//   i22  out: coordinate along dimension 2
//   i23  out: coordinate along dimension 3
//
// NOTE(review): fastdiv is declared outside this file; it is presumably an
// integer division helper -- confirm against generic_binary_head.comp.
void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
    // Peel off the outermost coordinate, then keep only what is left over.
    i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20));
    uint rest = idx - i23 * p.ne22*p.ne21*p.ne20;

    i22 = fastdiv(rest, (p.ne21*p.ne20));
    rest -= i22*p.ne21*p.ne20;

    // The remainder now spans a single (dim1, dim0) plane.
    i21 = rest / p.ne20;
    i20 = rest - i21*p.ne20;
}
18+
19+
// Shader entry point.
//
// Each invocation handles up to num_iter destination elements, spaced
// num_threads apart and starting at get_idx(). For every in-range element
// (i0, i1, i2, i3) it accumulates, over k in [0, p.ne01), the product of one
// dequantized value read from the A operand and one value read from data_b,
// then stores the sum in data_d.
//
// NOTE(review): dequantize4 and get_dm are declared outside this file
// (presumably dequant_funcs.comp). The index math below assumes that the four
// components returned for position iqs belong to elements 32 apart within a
// 128-element group of a QUANT_K block -- confirm against the A-type layout.
void main() {
    // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
    const uint num_iter = 2;

    // Divisors applied to the destination i2 / i3 coordinates when addressing A.
    const uint bcast_dim2 = uint(p.param2);
    const uint bcast_dim3 = p.param3;

    const uint first_idx = get_idx();

    [[unroll]] for (uint it = 0; it < num_iter; ++it) {
        const uint idx = first_idx + it * num_threads;
        if (idx >= p.ne) {
            continue;
        }

        uint i0, i1, i2, i3;
        get_dst_indices(idx, i0, i1, i2, i3);

        // Everything derived from i0 alone is the same for every k.
        const uint block_in_row = i0 / QUANT_K;
        const uint in_block     = i0 % QUANT_K;
        const uint iqs          = (in_block % 32u) + 32u * (in_block / 128u);
        const uint lane         = (in_block % 128u) / 32u;  // always in 0..3

        float acc = 0.0f;

        for (uint k = 0; k < p.ne01; ++k) {
            const uint row_base = get_aoffset() + (i3 / bcast_dim3) * p.nb03 + (i2 / bcast_dim2) * p.nb02 + k * p.nb01;
            const uint ib = row_base + block_in_row;

            const vec4 packed = dequantize4(ib, iqs, 0);
            const vec2 dm = get_dm(ib, 0);

            // Select the one component of the four that corresponds to i0.
            const float qv = packed[lane];
            const float a_val = qv * dm.x + dm.y;

            const float b = data_b[get_boffset() + src1_idx(i1, k, i2, i3)];
            acc += a_val * b;
        }

        data_d[get_doffset() + dst_idx(i0, i1, i2, i3)] = acc;
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ void process_shaders() {
718718
string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
719719
string_to_spv("out_prod_q4_0", "out_prod_q4_0.comp", merge_maps(base_dict, {{"DATA_A_Q4_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
720720
string_to_spv("out_prod_q8_0", "out_prod_q8_0.comp", merge_maps(base_dict, {{"DATA_A_Q8_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
721+
string_to_spv("out_prod_tq2_0", "out_prod_tq2_0.comp", merge_maps(base_dict, {{"DATA_A_TQ2_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
721722

722723
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
723724

0 commit comments

Comments
 (0)