@@ -3395,9 +3395,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
33953395
33963396 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
33973397
3398- ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256 , 1, 1}, { device->subgroup_size }, 1);
3398+ ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1 , 1, 1}, { device->subgroup_size }, 1);
33993399
3400- ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256 , 1, 1}, { device->subgroup_size }, 1);
3400+ ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1 , 1, 1}, { device->subgroup_size }, 1);
34013401
34023402 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
34033403 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@@ -8198,7 +8198,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
81988198 case GGML_OP_L2_NORM:
81998199 case GGML_OP_SOFT_MAX:
82008200 case GGML_OP_SOFT_MAX_BACK:
8201- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
82028201 case GGML_OP_SUM_ROWS:
82038202 case GGML_OP_MEAN:
82048203 case GGML_OP_ARGMAX:
@@ -8212,6 +8211,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
82128211 elements = { nr, 1, 1 };
82138212 }
82148213 } break;
8214+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
8215+ {
8216+ // For cross entropy loss back, we need one workgroup per row of logits (src1)
8217+ const uint32_t nr = ggml_nrows(src1);
8218+ if (nr > 262144) {
8219+ elements = { 512, 512, CEIL_DIV(nr, 262144) };
8220+ } else if (nr > 512) {
8221+ elements = { 512, CEIL_DIV(nr, 512), 1 };
8222+ } else {
8223+ elements = { nr, 1, 1 };
8224+ }
8225+ } break;
82158226 case GGML_OP_RMS_NORM:
82168227 if (ctx->do_add_rms_partials) {
82178228 // Run one element per thread, 128 threads per workgroup
0 commit comments