@@ -473,7 +473,10 @@ struct vk_device_struct {
473
473
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
474
474
vk_pipeline pipeline_argsort_f32;
475
475
vk_pipeline pipeline_sum_rows_f32;
476
- vk_pipeline pipeline_out_prod_f32, pipeline_out_prod_f16_f32;
476
+ vk_pipeline pipeline_out_prod_f32;
477
+ vk_pipeline pipeline_out_prod_f16_f32;
478
+ vk_pipeline pipeline_out_prod_q4_0;
479
+ vk_pipeline pipeline_out_prod_q8_0;
477
480
vk_pipeline pipeline_argmax_f32;
478
481
vk_pipeline pipeline_count_equal_i32;
479
482
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2937,6 +2940,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2937
2940
2938
2941
// TODO: should we have device->subgroup_size here or 0?
2939
2942
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2943
+ ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2944
+ ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2940
2945
2941
2946
// TODO: should we have device->subgroup_size here or 0?
2942
2947
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@@ -6753,8 +6758,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6753
6758
return nullptr;
6754
6759
}
6755
6760
case GGML_OP_OUT_PROD:
6756
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6757
- return ctx->device->pipeline_out_prod_f32;
6761
+ if (dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
6762
+ if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32;
6763
+ if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0;
6764
+ if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0;
6758
6765
}
6759
6766
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6760
6767
return ctx->device->pipeline_out_prod_f16_f32;
@@ -6931,7 +6938,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6931
6938
}
6932
6939
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
6933
6940
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
6934
- GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
6941
+ GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || op == GGML_OP_OUT_PROD || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
6935
6942
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
6936
6943
GGML_ASSERT(dst->buffer != nullptr);
6937
6944
const uint64_t ne00 = src0->ne[0];
@@ -10622,9 +10629,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10622
10629
case GGML_OP_GROUP_NORM:
10623
10630
case GGML_OP_L2_NORM:
10624
10631
return ggml_is_contiguous(op->src[0]);
10625
- case GGML_OP_OUT_PROD:
10626
- return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10627
- op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
10632
+ case GGML_OP_OUT_PROD: {
10633
+ const ggml_type t0 = op->src[0]->type;
10634
+ const ggml_type t1 = op->src[1]->type;
10635
+ const ggml_type td = op->type;
10636
+ if (td != GGML_TYPE_F32 || t1 != GGML_TYPE_F32) {
10637
+ return false;
10638
+ }
10639
+ switch (t0) {
10640
+ case GGML_TYPE_F32:
10641
+ case GGML_TYPE_F16:
10642
+ case GGML_TYPE_Q4_0:
10643
+ case GGML_TYPE_Q8_0:
10644
+ return true;
10645
+ default:
10646
+ return false;
10647
+ }
10648
+ }
10628
10649
case GGML_OP_ADD:
10629
10650
case GGML_OP_SUB:
10630
10651
case GGML_OP_MUL:
0 commit comments