@@ -543,6 +543,8 @@ struct vk_device_struct {
543543 vk_pipeline pipeline_relu[2];
544544 vk_pipeline pipeline_tanh[2];
545545 vk_pipeline pipeline_sigmoid[2];
546+ vk_pipeline pipeline_hardsigmoid[2];
547+ vk_pipeline pipeline_hardswish[2];
546548
547549 vk_pipeline pipeline_geglu[2];
548550 vk_pipeline pipeline_reglu[2];
@@ -3324,6 +3326,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
33243326 CREATE_UNARY(relu)
33253327 CREATE_UNARY(tanh)
33263328 CREATE_UNARY(sigmoid)
3329+ CREATE_UNARY(hardsigmoid)
3330+ CREATE_UNARY(hardswish)
33273331#undef CREATE_UNARY
33283332
33293333#define CREATE_GLU(name) \
@@ -7656,6 +7660,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
76567660 return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
76577661 case GGML_UNARY_OP_SIGMOID:
76587662 return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
7663+ case GGML_UNARY_OP_HARDSIGMOID:
7664+ return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
7665+ case GGML_UNARY_OP_HARDSWISH:
7666+ return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
76597667 default:
76607668 break;
76617669 }
@@ -10330,6 +10338,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1033010338 case GGML_UNARY_OP_RELU:
1033110339 case GGML_UNARY_OP_TANH:
1033210340 case GGML_UNARY_OP_SIGMOID:
10341+ case GGML_UNARY_OP_HARDSIGMOID:
10342+ case GGML_UNARY_OP_HARDSWISH:
1033310343 break;
1033410344 default:
1033510345 return false;
@@ -10711,6 +10721,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1071110721 case GGML_UNARY_OP_RELU:
1071210722 case GGML_UNARY_OP_TANH:
1071310723 case GGML_UNARY_OP_SIGMOID:
10724+ case GGML_UNARY_OP_HARDSIGMOID:
10725+ case GGML_UNARY_OP_HARDSWISH:
1071410726 ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
1071510727 break;
1071610728 default:
@@ -10955,6 +10967,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1095510967 case GGML_UNARY_OP_RELU:
1095610968 case GGML_UNARY_OP_TANH:
1095710969 case GGML_UNARY_OP_SIGMOID:
10970+ case GGML_UNARY_OP_HARDSIGMOID:
10971+ case GGML_UNARY_OP_HARDSWISH:
1095810972 buf = tensor->buffer;
1095910973 break;
1096010974 default:
@@ -12105,6 +12119,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1210512119 case GGML_UNARY_OP_RELU:
1210612120 case GGML_UNARY_OP_TANH:
1210712121 case GGML_UNARY_OP_SIGMOID:
12122+ case GGML_UNARY_OP_HARDSIGMOID:
12123+ case GGML_UNARY_OP_HARDSWISH:
1210812124 return ggml_is_contiguous(op->src[0]) &&
1210912125 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1211012126 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -12921,6 +12937,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1292112937 case GGML_UNARY_OP_SIGMOID:
1292212938 tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
1292312939 break;
12940+ case GGML_UNARY_OP_HARDSIGMOID:
12941+ tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
12942+ break;
12943+ case GGML_UNARY_OP_HARDSWISH:
12944+ tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
12945+ break;
1292412946 default:
1292512947 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1292612948 GGML_ABORT("fatal error");