Skip to content

Commit dbed612

Browse files
authored
vulkan: add LOG operation support for F32 and F16 (ggml-org#17183)
* vulkan: add LOG operation support for F32 and F16
  Part of ggml-org#14909.
* vulkan: Fix LOG operation types
* docs: Update operation support documentation for Vulkan LOG operation
* vulkan: fix log_f16 shader
* docs: restore missing LOG test cases and regenerate ops.md
1 parent 80deff3 commit dbed612

File tree

5 files changed

+50
-5
lines changed

5 files changed

+50
-5
lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Legend:
6363
| IM2COL_3D ||||||||||
6464
| L2_NORM ||||||||||
6565
| LEAKY_RELU |||||||| 🟡 ||
66-
| LOG ||||||| 🟡 | ||
66+
| LOG ||||||| 🟡 | ||
6767
| MEAN ||||||||||
6868
| MUL ||||| 🟡 | 🟡 ||||
6969
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |

docs/ops/Vulkan.csv

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8627,7 +8627,7 @@
86278627
"Vulkan0","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","1","yes","Vulkan"
86288628
"Vulkan0","SQR","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
86298629
"Vulkan0","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","Vulkan"
8630-
"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
8630+
"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","Vulkan"
86318631
"Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
86328632
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
86338633
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
@@ -8638,7 +8638,7 @@
86388638
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
86398639
"Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
86408640
"Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
8641-
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
8641+
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
86428642
"Vulkan0","SIN","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
86438643
"Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
86448644
"Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
@@ -8649,7 +8649,7 @@
86498649
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
86508650
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
86518651
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
8652-
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
8652+
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
86538653
"Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
86548654
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
86558655
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
@@ -8660,7 +8660,7 @@
86608660
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
86618661
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
86628662
"Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
8663-
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
8663+
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
86648664
"Vulkan0","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
86658665
"Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
86668666
"Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ struct vk_device_struct {
629629
vk_pipeline pipeline_sqrt_f32;
630630
vk_pipeline pipeline_sin_f32;
631631
vk_pipeline pipeline_cos_f32;
632+
vk_pipeline pipeline_log[2];
632633
vk_pipeline pipeline_clamp_f32;
633634
vk_pipeline pipeline_pad_f32;
634635
vk_pipeline pipeline_roll_f32;
@@ -3792,6 +3793,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
37923793
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
37933794
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
37943795
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3796+
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3797+
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
37953798

37963799
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
37973800

@@ -8126,6 +8129,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
81268129
return ctx->device->pipeline_cos_f32;
81278130
}
81288131
return nullptr;
8132+
case GGML_OP_LOG:
8133+
if (src0->type == dst->type &&
8134+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8135+
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
8136+
}
8137+
return nullptr;
81298138
case GGML_OP_CLAMP:
81308139
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
81318140
return ctx->device->pipeline_clamp_f32;
@@ -8534,6 +8543,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
85348543
case GGML_OP_SQRT:
85358544
case GGML_OP_SIN:
85368545
case GGML_OP_COS:
8546+
case GGML_OP_LOG:
85378547
case GGML_OP_CLAMP:
85388548
case GGML_OP_PAD:
85398549
case GGML_OP_REPEAT:
@@ -8806,6 +8816,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88068816
case GGML_OP_SQRT:
88078817
case GGML_OP_SIN:
88088818
case GGML_OP_COS:
8819+
case GGML_OP_LOG:
88098820
case GGML_OP_CLAMP:
88108821
case GGML_OP_PAD:
88118822
case GGML_OP_ROLL:
@@ -9414,6 +9425,10 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
94149425
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
94159426
}
94169427

9428+
static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9429+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
9430+
}
9431+
94179432
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
94189433
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
94199434
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11209,6 +11224,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1120911224
case GGML_OP_SQRT:
1121011225
case GGML_OP_SIN:
1121111226
case GGML_OP_COS:
11227+
case GGML_OP_LOG:
1121211228
case GGML_OP_CLAMP:
1121311229
case GGML_OP_PAD:
1121411230
case GGML_OP_ROLL:
@@ -11433,6 +11449,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1143311449
case GGML_OP_COS:
1143411450
ggml_vk_cos(ctx, compute_ctx, src0, node);
1143511451

11452+
break;
11453+
case GGML_OP_LOG:
11454+
ggml_vk_log(ctx, compute_ctx, src0, node);
11455+
1143611456
break;
1143711457
case GGML_OP_CLAMP:
1143811458
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -11703,6 +11723,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1170311723
case GGML_OP_SQRT:
1170411724
case GGML_OP_SIN:
1170511725
case GGML_OP_COS:
11726+
case GGML_OP_LOG:
1170611727
case GGML_OP_CLAMP:
1170711728
case GGML_OP_PAD:
1170811729
case GGML_OP_ROLL:
@@ -13664,6 +13685,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1366413685
case GGML_OP_OPT_STEP_ADAMW:
1366513686
case GGML_OP_OPT_STEP_SGD:
1366613687
return op->src[0]->type == GGML_TYPE_F32;
13688+
case GGML_OP_LOG:
13689+
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
1366713690
case GGML_OP_ARGSORT:
1366813691
return op->ne[0] <= max_argsort_cols;
1366913692
case GGML_OP_UPSCALE:
@@ -14159,6 +14182,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1415914182
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
1416014183
} else if (tensor->op == GGML_OP_COS) {
1416114184
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
14185+
} else if (tensor->op == GGML_OP_LOG) {
14186+
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1416214187
} else if (tensor->op == GGML_OP_CLAMP) {
1416314188
const float * params = (const float *)tensor->op_params;
1416414189
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
ggml/src/ggml-vulkan/vulkan-shaders/log.comp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#version 450
2+
3+
#include "types.glsl"
4+
#include "generic_unary_head.glsl"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
void main() {
9+
const uint idx = get_idx();
10+
11+
if (idx >= p.ne) {
12+
return;
13+
}
14+
15+
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
16+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
17+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,9 @@ void process_shaders() {
802802

803803
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
804804

805+
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
806+
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
807+
805808
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
806809

807810
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

0 commit comments

Comments
 (0)