Skip to content

Commit 0721550

Browse files
committed
vulkan: Fix cross-entropy-loss-back dispatch size and wg denominator
Signed-off-by: vineet <[email protected]>
1 parent 25c5316 commit 0721550

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2916,7 +2916,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29162916

29172917
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
29182918

2919-
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256, 1, 1}, { device->subgroup_size }, 1);
2919+
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
29202920

29212921
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
29222922
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@@ -7092,7 +7092,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70927092
case GGML_OP_L2_NORM:
70937093
case GGML_OP_SOFT_MAX:
70947094
case GGML_OP_SOFT_MAX_BACK:
7095-
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
70967095
case GGML_OP_SUM_ROWS:
70977096
case GGML_OP_ARGMAX:
70987097
{
@@ -7105,6 +7104,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
71057104
elements = { nr, 1, 1 };
71067105
}
71077106
} break;
7107+
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
7108+
{
7109+
// For cross entropy loss back, we need one workgroup per row of logits (src1)
7110+
const uint32_t nr = ggml_nrows(src1);
7111+
if (nr > 262144) {
7112+
elements = { 512, 512, CEIL_DIV(nr, 262144) };
7113+
} else if (nr > 512) {
7114+
elements = { 512, CEIL_DIV(nr, 512), 1 };
7115+
} else {
7116+
elements = { nr, 1, 1 };
7117+
}
7118+
} break;
71087119
case GGML_OP_RMS_NORM:
71097120
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
71107121
break;
@@ -7829,8 +7840,7 @@ static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_co
78297840
(uint32_t)nrows,
78307841
0.0f,
78317842
0.0f
7832-
}, dryrun);
7833-
7843+
}, dryrun);
78347844
}
78357845

78367846
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {

ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#define FLOAT_TYPE float
99

10-
layout(constant_id = 0) const uint BLOCK_SIZE = 256;
10+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
1111
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1212

1313
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // Grad(scalar)

0 commit comments

Comments
 (0)