@@ -473,6 +473,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_out_prod_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -2934,6 +2935,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
+    // TODO: should we have device->subgroup_size here or 0?
+    ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
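Not part of the diff: a comment-annotated reflow of the new pipeline-creation call, for orientation. The readings below are inferred from how the neighboring argsort/argmax pipelines are created, not from documentation:

```cpp
// Same call as above, reflowed with a hedged reading of each argument:
ggml_vk_create_pipeline(device,
    device->pipeline_out_prod_f32,        // handle declared in vk_device_struct above
    "out_prod_f32",                       // debug name
    out_prod_f32_len, out_prod_f32_data,  // SPIR-V blob compiled from the new shader
    "main",                               // shader entry point
    3,                                    // descriptor count: src0, src1 and dst buffers
    sizeof(vk_op_binary_push_constants),  // reuses the generic binary-op push constants
    {512, 1, 1},                          // workgroup shape used to derive the dispatch size
    { device->subgroup_size },            // specialization constant the TODO above asks about
    1);
```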
@@ -6745,6 +6749,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            }
            return nullptr;
        }
+    case GGML_OP_OUT_PROD:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_out_prod_f32;
+        }
+        return nullptr;
    case GGML_OP_ARGSORT:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
            return ctx->device->pipeline_argsort_f32;
@@ -6829,6 +6838,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
    switch (op) {
    case GGML_OP_CPY:
    case GGML_OP_GET_ROWS:
+    case GGML_OP_OUT_PROD:
    case GGML_OP_ADD:
    case GGML_OP_SUB:
    case GGML_OP_MUL:
@@ -7149,6 +7159,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    case GGML_OP_UPSCALE:
    case GGML_OP_UNARY:
    case GGML_OP_GLU:
+    case GGML_OP_OUT_PROD:
    case GGML_OP_CONV_2D_DW:
        {
            uint32_t ne = ggml_nelements(dst);
@@ -7894,6 +7905,24 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}

+static void ggml_vk_out_prod(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    const int64_t r2 = src1->ne[2] / src0->ne[2];
+    const int64_t r3 = src1->ne[3] / src0->ne[3];
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_OUT_PROD, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, (float) r2, (uint32_t) r3
+    }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    const int32_t s0 = dst->op_params[0];
    const int32_t s1 = dst->op_params[1];
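Not part of the diff: the shader itself (`out_prod_f32`) is compiled SPIR-V and not shown here, so as orientation, here is a minimal scalar sketch of what one invocation has to compute per dst element. It assumes ggml's out_prod semantics (dst[i0, i1] = Σ_k src0[i0, k] · src1[i1, k], with src0 broadcast over src1's batch dims by the r2/r3 ratios packed into the float params above); all names are illustrative, and strides are in elements, matching the nb/type_size division in the push constants:

```cpp
#include <cstdint>

// Illustrative only: one dst element of OUT_PROD, mirroring the push constants
// packed in ggml_vk_out_prod above.
static float out_prod_element(
        const float * src0, const float * src1,
        uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3,          // dst coordinates
        uint32_t ne01,                                               // shared dim: src0->ne[1] == src1->ne[1]
        uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,  // src0 strides (elements)
        uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,  // src1 strides (elements)
        uint32_t r2, uint32_t r3) {                                  // batch broadcast ratios
    const uint32_t i02 = i2 / r2;  // src0 batches repeat across src1's batches
    const uint32_t i03 = i3 / r3;
    float sum = 0.0f;
    for (uint32_t k = 0; k < ne01; ++k) {
        sum += src0[i03*nb03 + i02*nb02 + k*nb01 + i0*nb00]
             * src1[ i3*nb13 +  i2*nb12 + k*nb11 + i1*nb10];
    }
    return sum;  // dst[i0, i1, i2, i3]
}
```

With the {512, 1, 1} workgroup shape and the `ggml_nelements(dst)` element count wired into `ggml_vk_op_f32` above, this would correspond to one invocation per dst element, each looping over the shared dimension.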
@@ -9050,6 +9079,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_ROPE_BACK:
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
+    case GGML_OP_OUT_PROD:
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
@@ -9117,6 +9147,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
    case GGML_OP_ROPE_BACK:
+    case GGML_OP_OUT_PROD:
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
@@ -9156,6 +9187,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_GET_ROWS:
        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_OUT_PROD:
+        ggml_vk_out_prod(ctx, compute_ctx, src0, src1, node, dryrun);
+
        break;
    case GGML_OP_ADD:
        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9457,6 +9492,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_ARGMAX:
+    case GGML_OP_OUT_PROD:
    case GGML_OP_COUNT_EQUAL:
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
@@ -10580,6 +10616,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_GROUP_NORM:
        case GGML_OP_L2_NORM:
            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_OUT_PROD:
+            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) &&
+                   op->type == GGML_TYPE_F32;
        case GGML_OP_ADD:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
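Not part of the diff: a minimal, hypothetical way to exercise the F32-only path the check above admits, using ggml's public graph API (shapes are arbitrary; this mirrors the kind of case test-backend-ops builds):

```cpp
#include "ggml.h"

// Hypothetical smoke test: ggml's out_prod requires ne[1] (the reduction dim)
// of both operands to match; the result takes ne[0] from each operand.
static ggml_tensor * build_out_prod(ggml_context * ctx) {
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);  // src0: 64 x 32
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 48, 32);  // src1: 48 x 32
    return ggml_out_prod(ctx, a, b);  // dst: 64 x 48, F32 -> passes the check above
}
```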
@@ -11030,6 +11069,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
    } else if (tensor->op == GGML_OP_ADD) {
        tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_OUT_PROD) {
+        tensor_clone = ggml_out_prod(ggml_ctx, src_clone[0], src_clone[1]);
    } else if (tensor->op == GGML_OP_ACC) {
        tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
    } else if (tensor->op == GGML_OP_NORM) {