@@ -541,6 +541,7 @@ struct vk_device_struct {
541541
542542 vk_pipeline pipeline_leaky_relu_f32;
543543 vk_pipeline pipeline_silu_back_f32;
544+ vk_pipeline pipeline_geglu_back_f32;
544545 vk_pipeline pipeline_diag_mask_inf_f32;
545546 vk_pipeline pipeline_cross_entropy_loss_back_f32;
546547 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -3393,6 +3394,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
33933394 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
33943395 ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
33953396
3397+ ggml_vk_create_pipeline(device, device->pipeline_geglu_back_f32, "geglu_back_f32", geglu_back_f32_len, geglu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3398+
33963399 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
33973400
33983401 ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -7634,6 +7637,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
76347637 return ctx->device->pipeline_silu_back_f32;
76357638 }
76367639 return nullptr;
7640+ case GGML_OP_GEGLU_BACK:
7641+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7642+ return ctx->device->pipeline_geglu_back_f32;
7643+ }
7644+ return nullptr;
76377645 case GGML_OP_NORM:
76387646 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
76397647 return ctx->device->pipeline_norm_f32;
@@ -9064,6 +9072,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
90649072 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
90659073}
90669074
// Thin wrapper for GGML_OP_GEGLU_BACK: forwards its tensors to
// ggml_vk_op_f32<vk_op_push_constants>, using the same call shape as
// ggml_vk_silu_back.
//   ctx, subctx - passed through unchanged to ggml_vk_op_f32.
//   src0, src1  - the two source tensors; the third source slot is nullptr.
//   dst         - destination tensor (the only non-const tensor argument).
//   dryrun      - passed through unchanged; defaults to false.
// NOTE(review): which of src0/src1 carries the incoming gradient is not
// visible here - confirm against the geglu_back_f32 shader.
9075+ static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
// Push constants: the element count of src0 cast to uint32_t, then 0, 0.0f, 0.0f.
// NOTE(review): the field names of vk_op_push_constants are not visible here;
// presumably the zeros are unused by this op - TODO confirm.
9076+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
9077+ }
9078+
90679079static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
90689080 float * op_params = (float *)dst->op_params;
90699081
@@ -10585,6 +10597,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1058510597 case GGML_OP_CONT:
1058610598 case GGML_OP_DUP:
1058710599 case GGML_OP_SILU_BACK:
10600+ case GGML_OP_GEGLU_BACK:
1058810601 case GGML_OP_NORM:
1058910602 case GGML_OP_GROUP_NORM:
1059010603 case GGML_OP_RMS_NORM:
@@ -10658,6 +10671,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1065810671 case GGML_OP_CONT:
1065910672 case GGML_OP_DUP:
1066010673 case GGML_OP_SILU_BACK:
10674+ case GGML_OP_GEGLU_BACK:
1066110675 case GGML_OP_NORM:
1066210676 case GGML_OP_GROUP_NORM:
1066310677 case GGML_OP_RMS_NORM:
@@ -10872,6 +10886,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1087210886 case GGML_OP_SILU_BACK:
1087310887 ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
1087410888
10889+ break;
10890+ case GGML_OP_GEGLU_BACK:
10891+ ggml_vk_geglu_back(ctx, compute_ctx, src0, src1, node, dryrun);
10892+
1087510893 break;
1087610894 case GGML_OP_NORM:
1087710895 ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -11116,6 +11134,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1111611134 case GGML_OP_CONT:
1111711135 case GGML_OP_DUP:
1111811136 case GGML_OP_SILU_BACK:
11137+ case GGML_OP_GEGLU_BACK:
1111911138 case GGML_OP_NORM:
1112011139 case GGML_OP_GROUP_NORM:
1112111140 case GGML_OP_RMS_NORM:
@@ -12544,6 +12563,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1254412563 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
1254512564 op->type == GGML_TYPE_F32;
1254612565 case GGML_OP_SILU_BACK:
12566+ case GGML_OP_GEGLU_BACK:
1254712567 case GGML_OP_RMS_NORM_BACK:
1254812568 case GGML_OP_SQR:
1254912569 case GGML_OP_SQRT:
0 commit comments