Skip to content

Commit 00f8459

Browse files
makaveli10zoq
authored and committed
vulkan: Fix cross-entropy-loss-back dispatch size and wg denominator
Signed-off-by: vineet <[email protected]>
1 parent 5c34315 commit 00f8459

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3395,9 +3395,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
33953395

33963396
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
33973397

3398-
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256, 1, 1}, { device->subgroup_size }, 1);
3398+
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
33993399

3400-
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256, 1, 1}, { device->subgroup_size }, 1);
3400+
ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
34013401

34023402
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
34033403
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@@ -8198,7 +8198,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
81988198
case GGML_OP_L2_NORM:
81998199
case GGML_OP_SOFT_MAX:
82008200
case GGML_OP_SOFT_MAX_BACK:
8201-
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
82028201
case GGML_OP_SUM_ROWS:
82038202
case GGML_OP_MEAN:
82048203
case GGML_OP_ARGMAX:
@@ -8212,6 +8211,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
82128211
elements = { nr, 1, 1 };
82138212
}
82148213
} break;
8214+
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
8215+
{
8216+
// For cross entropy loss back, we need one workgroup per row of logits (src1)
8217+
const uint32_t nr = ggml_nrows(src1);
8218+
if (nr > 262144) {
8219+
elements = { 512, 512, CEIL_DIV(nr, 262144) };
8220+
} else if (nr > 512) {
8221+
elements = { 512, CEIL_DIV(nr, 512), 1 };
8222+
} else {
8223+
elements = { nr, 1, 1 };
8224+
}
8225+
} break;
82158226
case GGML_OP_RMS_NORM:
82168227
if (ctx->do_add_rms_partials) {
82178228
// Run one element per thread, 128 threads per workgroup

ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
#define FLOAT_TYPE float
99

10-
layout(constant_id = 0) const uint BLOCK_SIZE = 256;
10+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
1111
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1212

1313
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // Grad(scalar)

0 commit comments

Comments
 (0)