
Commit ca99485

zoqmakaveli10 authored and committed
Add GEGLU backward (Vulkan) to enable Gemma training.
1 parent a71fc37 commit ca99485

9 files changed: +211 −5 lines changed


ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
@@ -457,6 +457,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
+        GGML_OP_GEGLU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
@@ -1097,6 +1098,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * g);
+
     // hardswish(x) = x * relu6(x + 3) / 6
     GGML_API struct ggml_tensor * ggml_hardswish(
             struct ggml_context * ctx,
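
For orientation, a minimal sketch of how the new entry point might be exercised. Only ggml_geglu_back() itself comes from this commit; the context setup and tensor sizes below are hypothetical.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n = 64; // arbitrary example width
    struct ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // upstream gradient
    struct ggml_tensor * x    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // linear branch
    struct ggml_tensor * g    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // GELU-gated branch

    // Records a GGML_OP_GEGLU_BACK node with src[0..2] = {grad, x, g};
    // the result tensor has the same shape as x (see the ggml.c hunk below).
    struct ggml_tensor * dg = ggml_geglu_back(ctx, grad, x, g);
    (void) dg;

    ggml_free(ctx);
    return 0;
}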

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
@@ -1746,6 +1746,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_silu_back(params, tensor);
             } break;
+        case GGML_OP_GEGLU_BACK:
+            {
+                ggml_compute_forward_geglu_back(params, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor);
@@ -2182,6 +2186,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             }
             break;
         case GGML_OP_SILU_BACK:
+        case GGML_OP_GEGLU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_NORM:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 65 additions & 0 deletions
@@ -3185,6 +3185,71 @@ void ggml_compute_forward_silu_back(
     }
 }
 
+static void ggml_compute_forward_geglu_back_f32(
+        const ggml_compute_params * params,
+        const struct ggml_tensor * grad,
+        const struct ggml_tensor * x,
+        const struct ggml_tensor * g,
+        struct ggml_tensor * dst) {
+
+    GGML_ASSERT(ggml_can_repeat(grad, dst));
+    GGML_ASSERT(ggml_are_same_shape(x, g));
+    GGML_ASSERT(grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = dst->ne[0];
+
+    GGML_ASSERT(nc % 2 == 0);
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        for (int i2 = 0; i2 < dst->ne[2]; i2++) {
+            for (int i1 = ith; i1 < dst->ne[1]; i1 += nth) {
+                float * dst_ptr = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                const float * grad_ptr = (const float *)((char *) grad->data + i3*grad->nb[3] + i2*grad->nb[2] + i1*grad->nb[1]);
+                const float * x_ptr = (const float *)((char *) x->data + i3*x->nb[3] + i2*x->nb[2] + i1*x->nb[1]);
+                const float * g_ptr = (const float *)((char *) g->data + i3*g->nb[3] + i2*g->nb[2] + i1*g->nb[1]);
+
+                const int half = nc / 2;
+                ggml_vec_gelu_f32(half, dst_ptr, g_ptr);
+                ggml_vec_mul_f32(half, dst_ptr, dst_ptr, grad_ptr);
+                float * temp = (float *)alloca(half * sizeof(float));
+                ggml_vec_gelu_backward_f32(half, temp, g_ptr, grad_ptr);
+                ggml_vec_mul_f32(half, dst_ptr + half, temp, x_ptr);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_geglu_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * x = dst->src[1];
+    const struct ggml_tensor * g = dst->src[2];
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_back_f32(params, grad, x, g, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
 // ggml_compute_forward_reglu
 
 static void ggml_compute_forward_reglu_f32(
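
A scalar restatement of the row body above may help when reading it: for each output row, the first nc/2 columns receive grad·gelu(g) and the last nc/2 receive grad·gelu'(g)·x. The sketch below is illustrative only, assumes the same tanh-based GELU approximation used elsewhere in ggml, and geglu_back_row_ref is a hypothetical name.

#include <math.h>

// Illustrative per-row equivalent of the vectorized loop body above.
static void geglu_back_row_ref(int nc, float * dst, const float * grad,
                               const float * x, const float * g) {
    const float c = 0.797884560802865f; // SQRT_2_OVER_PI
    const float a = 0.044715f;          // GELU_COEF_A
    const int half = nc / 2;
    for (int j = 0; j < half; ++j) {
        const float u  = c * (g[j] + a*g[j]*g[j]*g[j]);
        const float t  = tanhf(u);
        const float ge = 0.5f * g[j] * (1.0f + t);                                           // gelu(g)
        const float dg = 0.5f * (1.0f + t + g[j]*(1.0f - t*t)*c*(1.0f + 3.0f*a*g[j]*g[j]));  // gelu'(g)
        dst[j]        = grad[j] * ge;        // ggml_vec_gelu_f32 then ggml_vec_mul_f32
        dst[j + half] = grad[j] * dg * x[j]; // ggml_vec_gelu_backward_f32 then ggml_vec_mul_f32
    }
}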

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ void ggml_compute_forward_repeat(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_geglu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cpu/vec.h

Lines changed: 26 additions & 0 deletions
@@ -944,6 +944,32 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
+inline static float ggml_gelu_backward_f32(float x, float dy) {
+    const float tanh_arg = SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x);
+    const float tanh_val = tanhf(tanh_arg);
+    const float sech2_val = 1.0f - tanh_val * tanh_val;
+    const float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * x * x) * sech2_val;
+    return dy * 0.5f * (1.0f + tanh_val + x * dtanh_dx);
+}
+
+inline static void ggml_vec_gelu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_gelu_backward_f32(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_gelu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float tanh_arg = SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi);
+        float tanh_val = tanhf(tanh_arg);
+        float sech2_val = 1.0f - tanh_val * tanh_val;
+        float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * xi * xi) * sech2_val;
+
+        dx[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy[i]) * 0.5f * (1.0f + tanh_val + xi * dtanh_dx));
+    }
+}
+
 inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
         y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
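
For reference, the formula these helpers implement is the derivative of ggml's tanh-based GELU approximation (a restatement of the code above, not part of the commit):

    gelu(x) ≈ 0.5 · x · (1 + tanh(u)),    u = sqrt(2/π) · (x + 0.044715 · x³)

    d/dx gelu(x) ≈ 0.5 · (1 + tanh(u)) + 0.5 · x · sech²(u) · sqrt(2/π) · (1 + 3 · 0.044715 · x²)

so ggml_gelu_backward_f32(x, dy) returns dy times that derivative, with sech²(u) computed as 1 − tanh²(u).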

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 23 additions & 3 deletions
@@ -463,6 +463,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
+    vk_pipeline pipeline_geglu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_cross_entropy_loss_back_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -2914,6 +2915,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_geglu_back_f32, "geglu_back_f32", geglu_back_f32_len, geglu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -6628,6 +6631,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_silu_back_f32;
         }
         return nullptr;
+    case GGML_OP_GEGLU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_geglu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -7761,6 +7769,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
@@ -7835,12 +7847,12 @@ static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_co
     const int64_t nclasses = src1->ne[0];
     const int64_t nrows = ggml_nrows(src1);
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, {
         (uint32_t)nclasses,
         (uint32_t)nrows,
         0.0f,
         0.0f
-    }, dryrun);
+    }, dryrun);
 }
 
 static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9112,6 +9124,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
    case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -9181,6 +9194,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -9303,6 +9317,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SILU_BACK:
         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_GEGLU_BACK:
+        ggml_vk_geglu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -9362,7 +9380,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
 
         break;
-
+
     case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         ggml_vk_cross_entropy_loss_back(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
@@ -9524,6 +9542,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -10693,6 +10712,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_SILU_BACK:
+        case GGML_OP_GEGLU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
         case GGML_OP_SIN:
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer GRAD {A_TYPE data_grad[];};
11+
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
12+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
13+
14+
float gelu(float x) {
15+
const float c = 0.797884560802865; // sqrt(2/pi)
16+
const float a = 0.044715;
17+
const float inner = c * (x + a * x * x * x);
18+
return 0.5 * x * (1.0 + tanh(inner));
19+
}
20+
21+
float gelu_derivative(float x) {
22+
const float c = 0.797884560802865; // sqrt(2/pi)
23+
const float a = 0.044715;
24+
const float x_squared = x * x;
25+
const float x_cubed = x_squared * x;
26+
const float inner = c * (x + a * x_cubed);
27+
const float tanh_val = tanh(inner);
28+
const float sech2_val = 1.0 - tanh_val * tanh_val;
29+
const float dtanh_dx = c * (1.0 + 3.0 * a * x_squared) * sech2_val;
30+
return 0.5 * (1.0 + tanh_val + x * dtanh_dx);
31+
}
32+
33+
void main() {
34+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
35+
36+
if (i >= p.KX) {
37+
return;
38+
}
39+
40+
const uint half_size = p.KX / 2;
41+
42+
if (i < half_size) {
43+
const float grad_val = float(data_grad[i]);
44+
const float g_val = float(data_x[i + half_size]);
45+
data_d[i] = D_TYPE(grad_val * gelu(g_val));
46+
} else {
47+
const uint idx = i - half_size;
48+
const float grad_val = float(data_grad[idx]);
49+
const float x_val = float(data_x[idx]);
50+
const float g_val = float(data_x[i]);
51+
data_d[i] = D_TYPE(grad_val * x_val * gelu_derivative(g_val));
52+
}
53+
}
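
A note on the indexing above: the shader folds the three components of gl_GlobalInvocationID into one element index i = z·262144 + y·512 + x (262144 = 512²), p.KX is the total output length, and binding 1 (data_x) is read as the two halves [x | g] laid out back to back. As a worked example with hypothetical sizes: for p.KX = 4096 (so half_size = 2048), the invocation with gl_GlobalInvocationID = (2348, 0, 0) gets i = 2348 ≥ 2048, hence idx = 300, and it writes data_d[2348] = data_grad[300] · data_x[300] · gelu_derivative(data_x[2348]), i.e. the second-half output for column 300.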

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -610,6 +610,7 @@ void process_shaders() {
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("geglu_back_f32", "geglu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
ggml/src/ggml.c

Lines changed: 30 additions & 2 deletions
@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "REPEAT_BACK",
     "CONCAT",
     "SILU_BACK",
+    "GEGLU_BACK",
     "NORM",
     "RMS_NORM",
     "RMS_NORM_BACK",
@@ -1010,7 +1011,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1036,6 +1037,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "repeat_back(x)",
     "concat(x, y)",
     "silu_back(x)",
+    "geglu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
@@ -1110,7 +1112,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2632,6 +2634,22 @@ struct ggml_tensor * ggml_silu_back(
     return result;
 }
 
+// ggml_geglu_back
+struct ggml_tensor * ggml_geglu_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * g) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, x);
+
+    result->op     = GGML_OP_GEGLU_BACK;
+    result->src[0] = grad;
+    result->src[1] = x;
+    result->src[2] = g;
+
+    return result;
+}
+
 // ggml hardswish
 
 struct ggml_tensor * ggml_hardswish(
@@ -6123,6 +6141,16 @@ static void ggml_compute_backward(
                     ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
                 }
             } break;
+            case GGML_GLU_OP_GEGLU: {
+                if (src0_needs_grads) {
+                    GGML_ASSERT(src1 && "backward pass only implemented for split geglu");
+                    ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_gelu(ctx, src1)));
+                }
+                if (src1_needs_grads) {
+                    struct ggml_tensor * grad_mul_src0 = ggml_mul(ctx, grad, src0);
+                    ggml_add_or_set(ctx, cgraph, isrc1, ggml_geglu_back(ctx, grad_mul_src0, src1, src1));
+                }
+            } break;
             default: {
                 GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
             } //break;
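
For context, the chain rule this backward case wires up: reading src0 as the linear branch a and src1 as the gated branch b (the reading implied by the hunk above), the split GEGLU output is out = a · gelu(b) elementwise, so an upstream gradient G propagates as

    dL/da = G · gelu(b),        dL/db = G · a · gelu'(b)

The first term is added to isrc0 directly; the second is what ggml_geglu_back is expected to produce from grad_mul_src0 = G · a and the gated tensor.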
