@@ -557,6 +557,7 @@ struct vk_device_struct {
557557 vk_pipeline pipeline_out_prod_f16_f32;
558558 vk_pipeline pipeline_out_prod_q4_0;
559559 vk_pipeline pipeline_out_prod_q8_0;
560+ vk_pipeline pipeline_out_prod_tq2_0;
560561 vk_pipeline pipeline_argmax_f32;
561562 vk_pipeline pipeline_count_equal_i32;
562563 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3432,6 +3433,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
34323433 ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
34333434 ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
34343435 ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
3436+ ggml_vk_create_pipeline(device, device->pipeline_out_prod_tq2_0, "out_prod_tq2_0", out_prod_tq2_0_len, out_prod_tq2_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
34353437
34363438 // TODO: should we have device->subgroup_size here or 0?
34373439 ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
@@ -7805,6 +7807,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
78057807 if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32;
78067808 if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0;
78077809 if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0;
7810+ if (src0->type == GGML_TYPE_TQ2_0) return ctx->device->pipeline_out_prod_tq2_0;
78087811 }
78097812 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
78107813 return ctx->device->pipeline_out_prod_f16_f32;
@@ -12558,6 +12561,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1255812561 case GGML_TYPE_F16:
1255912562 case GGML_TYPE_Q4_0:
1256012563 case GGML_TYPE_Q8_0:
12564+ case GGML_TYPE_TQ2_0:
1256112565 return true;
1256212566 default:
1256312567 return false;
0 commit comments