
Commit e9f5d88

Italo Nicola authored and makaveli10 committed
Vulkan: add support for fp32 OUT_PROD op
1 parent 0c1ffd1 commit e9f5d88

File tree

3 files changed: +97 -0 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 41 additions & 0 deletions
@@ -473,6 +473,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_out_prod_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2934,6 +2935,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
+    // TODO: should we have device->subgroup_size here or 0?
+    ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
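Note (not part of the commit): the pipeline above uses wg_denoms {512, 1, 1}, and ggml_vk_op_f32 sizes the dispatch from ggml_nelements(dst) (see the GGML_OP_OUT_PROD case added further down), so each workgroup is expected to cover 512 destination elements (256 shader threads times num_iter = 2, as the comment in out_prod.comp states). A rough C++ sketch of the resulting workgroup count, with an illustrative helper name:

    #include <cstdint>

    // Illustrative only: one out_prod_f32 workgroup covers num_threads (256) * num_iter (2)
    // = 512 destination elements, matching the {512, 1, 1} wg_denoms above.
    static uint32_t out_prod_num_workgroups(uint32_t dst_nelements) {
        const uint32_t elements_per_wg = 256 * 2;
        return (dst_nelements + elements_per_wg - 1) / elements_per_wg; // ceiling division
    }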
@@ -6745,6 +6749,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     }
+    case GGML_OP_OUT_PROD:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_out_prod_f32;
+        }
+        return nullptr;
     case GGML_OP_ARGSORT:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             return ctx->device->pipeline_argsort_f32;
@@ -6829,6 +6838,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     switch (op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_OUT_PROD:
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -7149,6 +7159,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
     case GGML_OP_GLU:
+    case GGML_OP_OUT_PROD:
     case GGML_OP_CONV_2D_DW:
         {
             uint32_t ne = ggml_nelements(dst);
@@ -7894,6 +7905,24 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_out_prod(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    const int64_t r2 = src1->ne[2] / src0->ne[2];
+    const int64_t r3 = src1->ne[3] / src0->ne[3];
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_OUT_PROD, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, (float) r2, (uint32_t) r3
+    }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
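The new helper reuses the generic vk_op_binary_push_constants layout: the destination element count, the ne/nb quadruples of src0, src1, and dst (byte strides divided by the element size), and the broadcast ratios r2 = src1->ne[2] / src0->ne[2] and r3 = src1->ne[3] / src0->ne[3] passed through the generic param2/param3 slots, which the shader reads back as broadcast2/broadcast3. As a reference, this C++ sketch (assuming contiguous f32 tensors; names are illustrative, not from the commit) mirrors what the shader computes for one destination element:

    #include <cstdint>
    #include <vector>

    // dst[i0, i1, i2, i3] = sum over i01 of src0[i0, i01, i2 / r2, i3 / r3] * src1[i1, i01, i2, i3],
    // i.e. a batched outer-product accumulation with src0 broadcast along dims 2/3.
    static float out_prod_ref_element(
            const std::vector<float> & a, const int64_t a_ne[4],
            const std::vector<float> & b, const int64_t b_ne[4],
            int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        const int64_t r2 = b_ne[2] / a_ne[2];
        const int64_t r3 = b_ne[3] / a_ne[3];
        float acc = 0.0f;
        for (int64_t i01 = 0; i01 < a_ne[1]; ++i01) { // shared inner dimension
            // contiguous ggml layout: flat index = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0
            const int64_t ai = ((i3 / r3 * a_ne[2] + i2 / r2) * a_ne[1] + i01) * a_ne[0] + i0;
            const int64_t bi = ((i3      * b_ne[2] + i2     ) * b_ne[1] + i01) * b_ne[0] + i1;
            acc += a[ai] * b[bi];
        }
        return acc;
    }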
@@ -9050,6 +9079,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
+    case GGML_OP_OUT_PROD:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -9117,6 +9147,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
+    case GGML_OP_OUT_PROD:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -9156,6 +9187,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_GET_ROWS:
         ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_OUT_PROD:
+        ggml_vk_out_prod(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9457,6 +9492,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_OUT_PROD:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -10580,6 +10616,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_GROUP_NORM:
     case GGML_OP_L2_NORM:
         return ggml_is_contiguous(op->src[0]);
+    case GGML_OP_OUT_PROD:
+        return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) &&
+               op->type == GGML_TYPE_F32;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -11030,6 +11069,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
     } else if (tensor->op == GGML_OP_ADD) {
         tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_OUT_PROD) {
+        tensor_clone = ggml_out_prod(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ACC) {
         tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
     } else if (tensor->op == GGML_OP_NORM) {
ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+#version 450
+
+#include "generic_binary_head.comp"
+#include "types.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
+    i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20));
+    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
+    i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20));
+    const uint i22_offset = i22*p.ne21*p.ne20;
+    i21 = (idx - i23_offset - i22_offset) / p.ne20;
+    i20 = idx - i23_offset - i22_offset - i21*p.ne20;
+}
+
+void main() {
+    // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
+    const uint num_iter = 2;
+
+    const uint broadcast2 = uint(p.param2);
+    const uint broadcast3 = p.param3;
+
+    uint idx = get_idx();
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+
+        uint i0, i1, i2, i3;
+        get_dst_indices(idx, i0, i1, i2, i3);
+
+        FLOAT_TYPE acc = FLOAT_TYPE(0.0);
+
+        for (uint i01 = 0; i01 < p.ne01; ++i01) {
+            uint a_idx = src0_idx(i0, i01, i2 / broadcast2, i3 / broadcast3);
+            uint b_idx = src1_idx(i1, i01, i2, i3);
+
+            FLOAT_TYPE a_val = FLOAT_TYPE(data_a[get_aoffset() + a_idx]);
+            FLOAT_TYPE b_val = FLOAT_TYPE(data_b[get_boffset() + b_idx]);
+
+            acc += a_val * b_val;
+        }
+
+        uint d_idx = dst_idx(i0, i1, i2, i3);
+        data_d[get_doffset() + d_idx] = D_TYPE(acc);
+
+        idx += num_threads;
+    }
+}
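For readability, here is a plain C++ mirror of get_dst_indices() (illustrative only, not part of the commit), with fastdiv from the shared shader headers assumed to behave like ordinary integer division and ne20 as the fastest-varying destination dimension:

    #include <cstdint>

    struct dst_idx4 { uint32_t i20, i21, i22, i23; };

    // Plain-division equivalent of get_dst_indices() above, for a dst of extents ne20 x ne21 x ne22 x ne23.
    static dst_idx4 get_dst_indices_ref(uint32_t idx, uint32_t ne20, uint32_t ne21, uint32_t ne22) {
        dst_idx4 r;
        r.i23 = idx / (ne22 * ne21 * ne20);
        const uint32_t i23_offset = r.i23 * ne22 * ne21 * ne20;
        r.i22 = (idx - i23_offset) / (ne21 * ne20);
        const uint32_t i22_offset = r.i22 * ne21 * ne20;
        r.i21 = (idx - i23_offset - i22_offset) / ne20;
        r.i20 = idx - i23_offset - i22_offset - r.i21 * ne20;
        return r;
    }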

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -633,6 +633,8 @@ void process_shaders() {
     string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
+    string_to_spv("out_prod_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
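With these changes, an OUT_PROD node built through the public ggml API can be routed to the new out_prod_f32 pipeline when the Vulkan backend is selected; in the repository this path is normally exercised via the existing backend tests (test-backend-ops). A minimal construction sketch (hypothetical shapes, not code from this commit):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        // a: 64 x 32, b: 48 x 32 (shared inner dim 32) -> out_prod result: 64 x 48
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 48, 32);
        struct ggml_tensor * d = ggml_out_prod(ctx, a, b); // d[i, j] = sum over k of a[i, k] * b[j, k]

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, d);
        // ... hand gf to a Vulkan-backed scheduler/backend to dispatch the new out_prod_f32 pipeline

        ggml_free(ctx);
        return 0;
    }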
