@@ -473,7 +473,7 @@ struct vk_device_struct {
473
473
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
474
474
vk_pipeline pipeline_argsort_f32;
475
475
vk_pipeline pipeline_sum_rows_f32;
476
- vk_pipeline pipeline_out_prod_f32;
476
+ vk_pipeline pipeline_out_prod_f32, pipeline_out_prod_f16_f32 ;
477
477
vk_pipeline pipeline_argmax_f32;
478
478
vk_pipeline pipeline_count_equal_i32;
479
479
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2938,6 +2938,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2938
2938
// TODO: should we have device->subgroup_size here or 0?
2939
2939
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2940
2940
2941
+ // TODO: should we have device->subgroup_size here or 0?
2942
+ ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2943
+
2941
2944
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
2942
2945
2943
2946
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -6753,6 +6756,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6753
6756
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6754
6757
return ctx->device->pipeline_out_prod_f32;
6755
6758
}
6759
+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6760
+ return ctx->device->pipeline_out_prod_f16_f32;
6761
+ }
6756
6762
return nullptr;
6757
6763
case GGML_OP_ARGSORT:
6758
6764
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
@@ -10617,13 +10623,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10617
10623
case GGML_OP_L2_NORM:
10618
10624
return ggml_is_contiguous(op->src[0]);
10619
10625
case GGML_OP_OUT_PROD:
10620
- return (op->src[0]->type == GGML_TYPE_F32 && op->src[1 ]->type == GGML_TYPE_F32 ) &&
10621
- op->type == GGML_TYPE_F32;
10626
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0 ]->type == GGML_TYPE_F16 ) &&
10627
+ op->src[1]->type == GGML_TYPE_F32 && op-> type == GGML_TYPE_F32;
10622
10628
case GGML_OP_ADD:
10623
10629
case GGML_OP_SUB:
10624
10630
case GGML_OP_MUL:
10625
10631
case GGML_OP_DIV:
10626
- return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10632
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10627
10633
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
10628
10634
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
10629
10635
case GGML_OP_SILU_BACK:
0 commit comments