Skip to content

Commit a62ab54

Browse files
relent95 and walidbr
authored and committed
ggml vulkan: add hardsigmoid and hardswish operations (ggml-org#15762)
1 parent 6545c87 commit a62ab54

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,8 @@ struct vk_device_struct {
543543
vk_pipeline pipeline_relu[2];
544544
vk_pipeline pipeline_tanh[2];
545545
vk_pipeline pipeline_sigmoid[2];
546+
vk_pipeline pipeline_hardsigmoid[2];
547+
vk_pipeline pipeline_hardswish[2];
546548

547549
vk_pipeline pipeline_geglu[2];
548550
vk_pipeline pipeline_reglu[2];
@@ -3324,6 +3326,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
33243326
CREATE_UNARY(relu)
33253327
CREATE_UNARY(tanh)
33263328
CREATE_UNARY(sigmoid)
3329+
CREATE_UNARY(hardsigmoid)
3330+
CREATE_UNARY(hardswish)
33273331
#undef CREATE_UNARY
33283332

33293333
#define CREATE_GLU(name) \
@@ -7656,6 +7660,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
76567660
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
76577661
case GGML_UNARY_OP_SIGMOID:
76587662
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
7663+
case GGML_UNARY_OP_HARDSIGMOID:
7664+
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
7665+
case GGML_UNARY_OP_HARDSWISH:
7666+
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
76597667
default:
76607668
break;
76617669
}
@@ -10330,6 +10338,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1033010338
case GGML_UNARY_OP_RELU:
1033110339
case GGML_UNARY_OP_TANH:
1033210340
case GGML_UNARY_OP_SIGMOID:
10341+
case GGML_UNARY_OP_HARDSIGMOID:
10342+
case GGML_UNARY_OP_HARDSWISH:
1033310343
break;
1033410344
default:
1033510345
return false;
@@ -10711,6 +10721,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1071110721
case GGML_UNARY_OP_RELU:
1071210722
case GGML_UNARY_OP_TANH:
1071310723
case GGML_UNARY_OP_SIGMOID:
10724+
case GGML_UNARY_OP_HARDSIGMOID:
10725+
case GGML_UNARY_OP_HARDSWISH:
1071410726
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
1071510727
break;
1071610728
default:
@@ -10955,6 +10967,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1095510967
case GGML_UNARY_OP_RELU:
1095610968
case GGML_UNARY_OP_TANH:
1095710969
case GGML_UNARY_OP_SIGMOID:
10970+
case GGML_UNARY_OP_HARDSIGMOID:
10971+
case GGML_UNARY_OP_HARDSWISH:
1095810972
buf = tensor->buffer;
1095910973
break;
1096010974
default:
@@ -12105,6 +12119,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1210512119
case GGML_UNARY_OP_RELU:
1210612120
case GGML_UNARY_OP_TANH:
1210712121
case GGML_UNARY_OP_SIGMOID:
12122+
case GGML_UNARY_OP_HARDSIGMOID:
12123+
case GGML_UNARY_OP_HARDSWISH:
1210812124
return ggml_is_contiguous(op->src[0]) &&
1210912125
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1211012126
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -12921,6 +12937,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1292112937
case GGML_UNARY_OP_SIGMOID:
1292212938
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
1292312939
break;
12940+
case GGML_UNARY_OP_HARDSIGMOID:
12941+
tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
12942+
break;
12943+
case GGML_UNARY_OP_HARDSWISH:
12944+
tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
12945+
break;
1292412946
default:
1292512947
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1292612948
GGML_ABORT("fatal error");
ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// Elementwise hardsigmoid: y = clamp((x + 3) / 6, 0, 1).
// One invocation per element; p.KX holds the total element count.
void main() {
    // Flatten the 3D dispatch grid into a linear element index
    // (512 matches local_size_x; 262144 = 512 * 512).
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (idx >= p.KX) {
        return;
    }

    const float val = float(data_a[idx]);
    // clamp(v, 0, 1) is equivalent to min(1.0f, max(0.0f, v)).
    data_d[idx] = D_TYPE(clamp((val + 3.0f) / 6.0f, 0.0f, 1.0f));
}
ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// Elementwise hardswish: y = x * clamp((x + 3) / 6, 0, 1)
// (i.e. x scaled by its hardsigmoid). One invocation per element;
// p.KX holds the total element count.
void main() {
    // Flatten the 3D dispatch grid into a linear element index
    // (512 matches local_size_x; 262144 = 512 * 512).
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (idx >= p.KX) {
        return;
    }

    const float val = float(data_a[idx]);
    // clamp(v, 0, 1) is equivalent to min(1.0f, max(0.0f, v)).
    data_d[idx] = D_TYPE(val * clamp((val + 3.0f) / 6.0f, 0.0f, 1.0f));
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,10 @@ void process_shaders() {
657657
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
658658
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
659659
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
660+
string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
661+
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
662+
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
663+
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
660664

661665
for (auto rte : {false, true}) {
662666
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments (0)