@@ -2916,7 +2916,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2916
2916
2917
2917
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
2918
2918
2919
- ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {256 , 1, 1}, { device->subgroup_size }, 1);
2919
+ ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1 , 1, 1}, { device->subgroup_size }, 1);
2920
2920
2921
2921
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2922
2922
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@@ -7092,7 +7092,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7092
7092
case GGML_OP_L2_NORM:
7093
7093
case GGML_OP_SOFT_MAX:
7094
7094
case GGML_OP_SOFT_MAX_BACK:
7095
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
7096
7095
case GGML_OP_SUM_ROWS:
7097
7096
case GGML_OP_ARGMAX:
7098
7097
{
@@ -7105,6 +7104,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7105
7104
elements = { nr, 1, 1 };
7106
7105
}
7107
7106
} break;
7107
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
7108
+ {
7109
+ // For cross entropy loss back, we need one workgroup per row of logits (src1)
7110
+ const uint32_t nr = ggml_nrows(src1);
7111
+ if (nr > 262144) {
7112
+ elements = { 512, 512, CEIL_DIV(nr, 262144) };
7113
+ } else if (nr > 512) {
7114
+ elements = { 512, CEIL_DIV(nr, 512), 1 };
7115
+ } else {
7116
+ elements = { nr, 1, 1 };
7117
+ }
7118
+ } break;
7108
7119
case GGML_OP_RMS_NORM:
7109
7120
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
7110
7121
break;
@@ -7829,8 +7840,7 @@ static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_co
7829
7840
(uint32_t)nrows,
7830
7841
0.0f,
7831
7842
0.0f
7832
- }, dryrun);
7833
-
7843
+ }, dryrun);
7834
7844
}
7835
7845
7836
7846
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
0 commit comments