
Commit d0acca3

Add GEGLU backward (Vulkan) to enable Gemma training.
Parent: cd8c3e6

File tree: 9 files changed, +208 −2 lines changed
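Note (not part of the diff): the split GEGLU used in Gemma's feed-forward block is an elementwise product of a linear half x and a GELU-gated half g, so the backward pass added by this commit follows from the product and chain rules:

    y = x \odot \mathrm{GELU}(g)
    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \odot \mathrm{GELU}(g)
    \frac{\partial L}{\partial g} = \frac{\partial L}{\partial y} \odot x \odot \mathrm{GELU}'(g)

These two expressions are what the new CPU kernel, Vulkan shader, and autodiff hook below are built around.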

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
@@ -479,6 +479,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
+        GGML_OP_GEGLU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,

@@ -1130,6 +1131,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * g);
+
     // hardswish(x) = x * relu6(x + 3) / 6
     GGML_API struct ggml_tensor * ggml_hardswish(
             struct ggml_context * ctx,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
@@ -1759,6 +1759,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_silu_back(params, tensor);
             } break;
+        case GGML_OP_GEGLU_BACK:
+            {
+                ggml_compute_forward_geglu_back(params, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor);

@@ -2210,6 +2214,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             }
             break;
         case GGML_OP_SILU_BACK:
+        case GGML_OP_GEGLU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_NORM:
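Note (not part of the diff): grouping GGML_OP_GEGLU_BACK with SILU_BACK/MUL/NORM in ggml_get_n_tasks gives it n_tasks = n_threads, and the CPU kernel below parallelizes over rows with the usual stride pattern (i1 starts at ith and advances by nth). Worked example under that scheme: with 4 threads and dst->ne[1] = 10 rows, thread 0 processes rows 0, 4, 8; thread 1 rows 1, 5, 9; thread 2 rows 2, 6; thread 3 rows 3, 7.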

ggml/src/ggml-cpu/ops.cpp

Lines changed: 65 additions & 0 deletions
@@ -3420,6 +3420,71 @@ void ggml_compute_forward_silu_back(
     }
 }
 
+static void ggml_compute_forward_geglu_back_f32(
+        const ggml_compute_params * params,
+        const struct ggml_tensor * grad,
+        const struct ggml_tensor * x,
+        const struct ggml_tensor * g,
+        struct ggml_tensor * dst) {
+
+    GGML_ASSERT(ggml_can_repeat(grad, dst));
+    GGML_ASSERT(ggml_are_same_shape(x, g));
+    GGML_ASSERT(grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = dst->ne[0];
+
+    GGML_ASSERT(nc % 2 == 0);
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        for (int i2 = 0; i2 < dst->ne[2]; i2++) {
+            for (int i1 = ith; i1 < dst->ne[1]; i1 += nth) {
+                float * dst_ptr = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                const float * grad_ptr = (const float *)((char *) grad->data + i3*grad->nb[3] + i2*grad->nb[2] + i1*grad->nb[1]);
+                const float * x_ptr = (const float *)((char *) x->data + i3*x->nb[3] + i2*x->nb[2] + i1*x->nb[1]);
+                const float * g_ptr = (const float *)((char *) g->data + i3*g->nb[3] + i2*g->nb[2] + i1*g->nb[1]);
+
+                const int half = nc / 2;
+                ggml_vec_gelu_f32(half, dst_ptr, g_ptr);
+                ggml_vec_mul_f32(half, dst_ptr, dst_ptr, grad_ptr);
+                float * temp = (float *)alloca(half * sizeof(float));
+                ggml_vec_gelu_backward_f32(half, temp, g_ptr, grad_ptr);
+                ggml_vec_mul_f32(half, dst_ptr + half, temp, x_ptr);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_geglu_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * x = dst->src[1];
+    const struct ggml_tensor * g = dst->src[2];
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_back_f32(params, grad, x, g, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
 // ggml_compute_forward_reglu
 
 static void ggml_compute_forward_reglu_f32(
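Note (not part of the diff): the f32 kernel above packs both gradients into one destination row — the first nc/2 columns receive grad · GELU(g) and the second nc/2 columns receive grad · x · GELU'(g). A minimal scalar sketch of that per-row computation follows; the gelu_ref/gelu_grad_ref helpers are hypothetical stand-ins for the ggml_vec_* routines, not ggml API.

#include <math.h>

// Hypothetical scalar reference for one row of the kernel above (illustrative,
// not the committed implementation). dst has nc = 2*half columns: [d_x | d_g].
static float gelu_ref(float v) {
    const float c = 0.797884560802865f; // sqrt(2/pi)
    const float a = 0.044715f;
    return 0.5f * v * (1.0f + tanhf(c * v * (1.0f + a * v * v)));
}

static float gelu_grad_ref(float v) {
    const float c = 0.797884560802865f; // sqrt(2/pi)
    const float a = 0.044715f;
    const float t = tanhf(c * v * (1.0f + a * v * v));
    const float s = 1.0f - t * t;       // sech^2
    return 0.5f * (1.0f + t + v * c * (1.0f + 3.0f * a * v * v) * s);
}

static void geglu_back_row_ref(int half, float * dst,
                               const float * grad, const float * x, const float * g) {
    for (int i = 0; i < half; ++i) {
        dst[i]        = grad[i] * gelu_ref(g[i]);             // gradient w.r.t. x
        dst[half + i] = grad[i] * x[i] * gelu_grad_ref(g[i]); // gradient w.r.t. g
    }
}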

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ void ggml_compute_forward_repeat(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_geglu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cpu/vec.h

Lines changed: 26 additions & 0 deletions
@@ -1308,6 +1308,32 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
+inline static float ggml_gelu_backward_f32(float x, float dy) {
+    const float tanh_arg = SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x);
+    const float tanh_val = tanhf(tanh_arg);
+    const float sech2_val = 1.0f - tanh_val * tanh_val;
+    const float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * x * x) * sech2_val;
+    return dy * 0.5f * (1.0f + tanh_val + x * dtanh_dx);
+}
+
+inline static void ggml_vec_gelu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_gelu_backward_f32(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_gelu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float tanh_arg = SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi);
+        float tanh_val = tanhf(tanh_arg);
+        float sech2_val = 1.0f - tanh_val * tanh_val;
+        float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * xi * xi) * sech2_val;
+
+        dx[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy[i]) * 0.5f * (1.0f + tanh_val + xi * dtanh_dx));
+    }
+}
+
 inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
         y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
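Note (not part of the diff): the ggml_gelu_backward_* helpers above differentiate the tanh approximation of GELU already used by ggml (SQRT_2_OVER_PI and GELU_COEF_A are the existing constants). With u(x) = \sqrt{2/\pi}\,x\,(1 + 0.044715\,x^2):

    \mathrm{GELU}(x) \approx \tfrac{1}{2}\,x\,(1 + \tanh u)
    u'(x) = \sqrt{2/\pi}\,(1 + 3 \cdot 0.044715\,x^2)
    \mathrm{GELU}'(x) \approx \tfrac{1}{2}\,(1 + \tanh u) + \tfrac{1}{2}\,x\,\operatorname{sech}^2(u)\,u'(x)

which is the 0.5f * (1.0f + tanh_val + x * dtanh_dx) expression in the code, with dtanh_dx = sech2_val * u'(x).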

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 20 additions & 0 deletions
@@ -541,6 +541,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
+    vk_pipeline pipeline_geglu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_cross_entropy_loss_back_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;

@@ -3393,6 +3394,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_geglu_back_f32, "geglu_back_f32", geglu_back_f32_len, geglu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

@@ -7634,6 +7637,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_silu_back_f32;
         }
         return nullptr;
+    case GGML_OP_GEGLU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_geglu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;

@@ -9064,6 +9072,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
@@ -10585,6 +10597,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:

@@ -10658,6 +10671,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:

@@ -10872,6 +10886,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SILU_BACK:
         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_GEGLU_BACK:
+        ggml_vk_geglu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);

@@ -11116,6 +11134,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
+    case GGML_OP_GEGLU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:

@@ -12544,6 +12563,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
                    op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
+        case GGML_OP_GEGLU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
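Note (not part of the diff): the dispatch reuses the generic element-wise push-constant path, and the only value the new shader reads is p.KX, the total element count of src0; the remaining fields in the initializer { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f } are left at zero. A sketch of the presumed layout of that push-constant block, inferred from the four-value initializer (field names are illustrative, not quoted from this diff):

// Presumed generic element-wise push constants as used above (a sketch).
struct vk_op_push_constants_sketch {
    uint32_t KX;      // total element count of src0; geglu_back.comp reads p.KX
    uint32_t KY;      // unused here (0)
    float    param1;  // unused here (0.0f)
    float    param2;  // unused here (0.0f)
};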
ggml/src/ggml-vulkan/vulkan-shaders/geglu_back.comp

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer GRAD {A_TYPE data_grad[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+float gelu(float x) {
+    const float c = 0.797884560802865; // sqrt(2/pi)
+    const float a = 0.044715;
+    const float inner = c * (x + a * x * x * x);
+    return 0.5 * x * (1.0 + tanh(inner));
+}
+
+float gelu_derivative(float x) {
+    const float c = 0.797884560802865; // sqrt(2/pi)
+    const float a = 0.044715;
+    const float x_squared = x * x;
+    const float x_cubed = x_squared * x;
+    const float inner = c * (x + a * x_cubed);
+    const float tanh_val = tanh(inner);
+    const float sech2_val = 1.0 - tanh_val * tanh_val;
+    const float dtanh_dx = c * (1.0 + 3.0 * a * x_squared) * sech2_val;
+    return 0.5 * (1.0 + tanh_val + x * dtanh_dx);
+}
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const uint half_size = p.KX / 2;
+
+    if (i < half_size) {
+        const float grad_val = float(data_grad[i]);
+        const float g_val = float(data_x[i + half_size]);
+        data_d[i] = D_TYPE(grad_val * gelu(g_val));
+    } else {
+        const uint idx = i - half_size;
+        const float grad_val = float(data_grad[idx]);
+        const float x_val = float(data_x[idx]);
+        const float g_val = float(data_x[i]);
+        data_d[i] = D_TYPE(grad_val * x_val * gelu_derivative(g_val));
+    }
+}
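Note (not part of the diff; the file name above is inferred from the "geglu_back.comp" registration in vulkan-shaders-gen.cpp below): the shader runs one invocation per output element, flattening the dispatch as i = z·262144 + y·512 + x to match the 512-wide workgroups. As wired here, binding 1 (data_x) holds the linear half in [0, KX/2) and the gated half in [KX/2, KX), and the output is packed the same way, so each element computes:

    d[i] = grad[i] * GELU(x[i + KX/2])                 for i <  KX/2
    d[i] = grad[i - KX/2] * x[i - KX/2] * GELU'(x[i])  for i >= KX/2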

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -684,6 +684,7 @@ void process_shaders() {
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("geglu_back_f32", "geglu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
ggml/src/ggml.c

Lines changed: 30 additions & 2 deletions
@@ -942,6 +942,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "REPEAT_BACK",
     "CONCAT",
     "SILU_BACK",
+    "GEGLU_BACK",
     "NORM",
     "RMS_NORM",
     "RMS_NORM_BACK",

@@ -1019,7 +1020,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1046,6 +1047,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "repeat_back(x)",
     "concat(x, y)",
     "silu_back(x)",
+    "geglu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",

@@ -1123,7 +1125,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2667,6 +2669,22 @@ struct ggml_tensor * ggml_silu_back(
     return result;
 }
 
+// ggml_geglu_back
+struct ggml_tensor * ggml_geglu_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * g) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, x);
+
+    result->op     = GGML_OP_GEGLU_BACK;
+    result->src[0] = grad;
+    result->src[1] = x;
+    result->src[2] = g;
+
+    return result;
+}
+
 // ggml_hardswish
 
 struct ggml_tensor * ggml_hardswish(

@@ -6389,6 +6407,16 @@ static void ggml_compute_backward(
                 ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
             }
         } break;
+        case GGML_GLU_OP_GEGLU: {
+            if (src0_needs_grads) {
+                GGML_ASSERT(src1 && "backward pass only implemented for split geglu");
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_gelu(ctx, src1)));
+            }
+            if (src1_needs_grads) {
+                struct ggml_tensor * grad_mul_src0 = ggml_mul(ctx, grad, src0);
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_geglu_back(ctx, grad_mul_src0, src1, src1));
+            }
+        } break;
        default: {
            GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
        } //break;
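Note (not part of the diff): for a split GEGLU node out = src0 ⊙ GELU(src1), the src0 branch of the new hook accumulates the product-rule term directly,

    \frac{\partial L}{\partial \mathrm{src0}} \mathrel{+}= grad \odot \mathrm{GELU}(\mathrm{src1}),

while the src1 branch first forms grad ⊙ src0 (grad_mul_src0) and hands it to ggml_geglu_back together with src1, leaving the GELU'(src1) factor to the backend kernels added above.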
