
Commit 2b0c835

Italo Nicola authored and makaveli10 committed
Vulkan: add support for f16_f32 OUT_PROD op
1 parent fb0e501 commit 2b0c835

File tree (3 files changed: +13 -4 lines changed)

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

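For context, a minimal sketch of the op being extended, using the public ggml API (the calls are standard ggml; the shapes are illustrative). src0 = F16 with src1 = F32 and dst = F32 is exactly the combination this commit enables on the Vulkan backend:

    #include "ggml.h"

    int main() {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        // src0 in f16, src1 in f32 -- the new combination on Vulkan
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 32);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 48, 32);

        // OUT_PROD contracts over the shared second dimension (32) and
        // always produces an f32 result, here of shape [64, 48]
        struct ggml_tensor * out = ggml_out_prod(ctx, a, b);
        GGML_ASSERT(out->type == GGML_TYPE_F32);

        ggml_free(ctx);
        return 0;
    }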

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 10 additions & 4 deletions
@@ -473,7 +473,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
-    vk_pipeline pipeline_out_prod_f32;
+    vk_pipeline pipeline_out_prod_f32, pipeline_out_prod_f16_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2938,6 +2938,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // TODO: should we have device->subgroup_size here or 0?
     ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
+    // TODO: should we have device->subgroup_size here or 0?
+    ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -6753,6 +6756,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_out_prod_f32;
         }
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_out_prod_f16_f32;
+        }
         return nullptr;
     case GGML_OP_ARGSORT:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
@@ -10617,13 +10623,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_L2_NORM:
         return ggml_is_contiguous(op->src[0]);
     case GGML_OP_OUT_PROD:
-        return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) &&
-               op->type == GGML_TYPE_F32;
+        return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+               op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
-        return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+        return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
     case GGML_OP_SILU_BACK:
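To make the hunks above concrete: a hypothetical scalar reference of what the new f16_f32 path computes. This is not code from this commit, and the layout assumptions (ne[0] contiguous, dot product over the shared ne[1] dimension) are my reading of ggml's conventions:

    #include "ggml.h"

    // Illustrative only: src0 is read as f16, converted per element,
    // and accumulated in f32; src1 and dst stay in f32 throughout.
    static void out_prod_f16_f32_ref(const ggml_fp16_t * a, // src0: m columns, k rows
                                     const float       * b, // src1: n columns, k rows
                                     float             * d, // dst:  m columns, n rows
                                     int m, int n, int k) {
        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < m; ++c) {
                float sum = 0.0f;
                for (int i = 0; i < k; ++i) {
                    sum += ggml_fp16_to_fp32(a[i*m + c]) * b[i*n + r];
                }
                d[r*m + c] = sum;
            }
        }
    }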

ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 #version 450
 
+#extension GL_EXT_shader_16bit_storage : require
+
 #include "generic_binary_head.comp"
 #include "types.comp"
 
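The required GL_EXT_shader_16bit_storage extension lets the shader declare float16_t members in storage buffers; on the API side this maps to the 16-bit storage feature (core since Vulkan 1.1). A sketch of the matching device-feature query — illustrative only, not this backend's actual probing code:

    #include <vulkan/vulkan.h>

    // Illustrative check: can this device read 16-bit values from
    // storage buffers, as GL_EXT_shader_16bit_storage requires?
    bool supports_16bit_storage(VkPhysicalDevice physical_device) {
        VkPhysicalDevice16BitStorageFeatures storage16 = {};
        storage16.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;

        VkPhysicalDeviceFeatures2 features2 = {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        features2.pNext = &storage16;

        vkGetPhysicalDeviceFeatures2(physical_device, &features2);
        return storage16.storageBuffer16BitAccess == VK_TRUE;
    }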

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -634,6 +634,7 @@ void process_shaders() {
     string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("out_prod_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
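The A_TYPE/B_TYPE/D_TYPE entries become preprocessor defines when the shader is compiled, which is how the single out_prod.comp source yields both the f32 and the f16_f32 variant. A rough illustration of that mechanism (build_define_flags is a hypothetical helper, not the generator's real code):

    #include <map>
    #include <string>

    // Hypothetical: turn a define map into compiler flags, the way the
    // generator specializes one shader source per type combination.
    std::string build_define_flags(const std::map<std::string, std::string> & defs) {
        std::string flags;
        for (const auto & kv : defs) {
            flags += " -D" + kv.first + "=" + kv.second;
        }
        return flags;
    }

    // build_define_flags({{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})
    //   -> " -DA_TYPE=float16_t -DB_TYPE=float -DD_TYPE=float"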
