93 changes: 65 additions & 28 deletions common/common.cpp
@@ -1636,11 +1636,21 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
} else {
chat_template_source.assign(std::istreambuf_iterator<char>(tmpl_file), std::istreambuf_iterator<char>());
tmpl_file.close();
try {
chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
} catch (const std::exception & e) {
LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
}
}
}

try {
chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
if (chat_template_source.empty()) {
LOG_INF("Using model's built-in chat template\n");
} else {
LOG_INF("Using custom chat template from: %s\n", chat_template_path.c_str());
}
} catch (const std::exception & e) {
if (!chat_template_path.empty()) {
LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
} else {
LOG_ERR("Warning: Failed to initialize chat template: %s\n", e.what());
}
}

@@ -1756,33 +1766,60 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
std::vector<Span> assistant_spans;

{
size_t from = 0;
while (true) {
size_t open = render.find(START_AST, from);
if (open == std::string::npos) break;

// Include the role token ("assistant") and everything through the closing tag/newlines
size_t lo = open + START_TAG.size();
if (lo > render.size()) {
lo = render.size();
}
bool is_gemma = render.find("<start_of_turn>model\n") != std::string::npos;

if (is_gemma) {
const std::string GEMMA_START = "<start_of_turn>model\n";
const std::string GEMMA_END = "<end_of_turn>";

size_t from = 0;
while (true) {
size_t open = render.find(GEMMA_START, from);
if (open == std::string::npos) break;
size_t lo = open;
size_t close = render.find(GEMMA_END, lo);
if (close == std::string::npos) {
assistant_spans.push_back({lo, render.size()});
break;
}

size_t close = render.find(END_TAG, open + START_AST.size());
if (close == std::string::npos) {
assistant_spans.push_back({lo, render.size()});
break;
}
size_t hi = close + GEMMA_END.size();
if (hi < render.size() && render[hi] == '\n') {
hi++;
}
assistant_spans.push_back({lo, std::min(hi, render.size())});

size_t hi = close + END_TAG.size();
if (hi <= lo) {
lo = open;
hi = close + END_TAG.size();
from = hi;
}
} else {
size_t from = 0;
while (true) {
size_t open = render.find(START_AST, from);
if (open == std::string::npos) break;

// Include the role token ("assistant") and everything through the closing tag/newlines
size_t lo = open + START_TAG.size();
if (lo > render.size()) {
lo = render.size();
}

assistant_spans.push_back({lo, std::min(hi, render.size())});
size_t close = render.find(END_TAG, open + START_AST.size());
if (close == std::string::npos) {
assistant_spans.push_back({lo, render.size()});
break;
}

size_t next_from = hi;
from = next_from;
size_t hi = close + END_TAG.size();
if (hi <= lo) {
lo = open;
hi = close + END_TAG.size();
}

assistant_spans.push_back({lo, std::min(hi, render.size())});

size_t next_from = hi;
from = next_from;
}
}
}

@@ -1814,7 +1851,7 @@ ggml_opt_dataset_t common_opt_sft_dataset_init(
LOG_WRN("Warning: Conversation %zu has zero assistant tokens after masking\n", i);
continue;
}

all_tokenized_data.push_back(tokens_full);
all_assistant_masks.push_back(assistant_mask);
}
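
The masking step that follows (collapsed in the diff above) still has to turn these character spans into a per-token assistant mask. A minimal sketch of that mapping, assuming the caller can provide per-token character offsets into the rendered string — `token_offsets` and `build_assistant_mask` are illustrative names, not helpers from this PR:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct Span { size_t lo, hi; };  // character range [lo, hi) in the rendered text

// Mark every token whose character range overlaps an assistant span.
// token_offsets[i] is the [begin, end) character range of token i; how these
// offsets are obtained is up to the caller (e.g. incremental tokenization).
static std::vector<int32_t> build_assistant_mask(
        const std::vector<std::pair<size_t, size_t>> & token_offsets,
        const std::vector<Span> & assistant_spans) {
    std::vector<int32_t> mask(token_offsets.size(), 0);
    for (size_t i = 0; i < token_offsets.size(); ++i) {
        const size_t tok_lo = token_offsets[i].first;
        const size_t tok_hi = token_offsets[i].second;
        for (const Span & s : assistant_spans) {
            if (tok_lo < s.hi && s.lo < tok_hi) {  // half-open ranges overlap
                mask[i] = 1;
                break;
            }
        }
    }
    return mask;
}

A conversation whose mask comes back all-zero would then trigger the "zero assistant tokens after masking" warning above.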
1 change: 0 additions & 1 deletion ggml/include/ggml.h
@@ -1144,7 +1144,6 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_geglu_back(
struct ggml_context * ctx,
struct ggml_tensor * grad,
struct ggml_tensor * x,
struct ggml_tensor * g);

// hardswish(x) = x * relu6(x + 3) / 6
20 changes: 4 additions & 16 deletions ggml/src/ggml-cpu/ops.cpp
@@ -3517,14 +3517,12 @@ void ggml_compute_forward_silu_back(
static void ggml_compute_forward_geglu_back_f32(
const ggml_compute_params * params,
const struct ggml_tensor * grad,
const struct ggml_tensor * x,
const struct ggml_tensor * g,
struct ggml_tensor * dst) {

GGML_ASSERT(ggml_can_repeat(grad, dst));
GGML_ASSERT(ggml_are_same_shape(x, g));
GGML_ASSERT(ggml_are_same_shape(g, dst));
GGML_ASSERT(grad->type == GGML_TYPE_F32);
GGML_ASSERT(x->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

@@ -3533,8 +3531,6 @@ static void ggml_compute_forward_geglu_back_f32(

const int nc = dst->ne[0];

GGML_ASSERT(nc % 2 == 0);

const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
@@ -3544,15 +3540,8 @@
for (int i1 = ith; i1 < dst->ne[1]; i1 += nth) {
float * dst_ptr = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float * grad_ptr = (const float *)((char *) grad->data + i3*grad->nb[3] + i2*grad->nb[2] + i1*grad->nb[1]);
const float * x_ptr = (const float *)((char *) x->data + i3*x->nb[3] + i2*x->nb[2] + i1*x->nb[1]);
const float * g_ptr = (const float *)((char *) g->data + i3*g->nb[3] + i2*g->nb[2] + i1*g->nb[1]);

const int half = nc / 2;
ggml_vec_gelu_f32(half, dst_ptr, g_ptr);
ggml_vec_mul_f32(half, dst_ptr, dst_ptr, grad_ptr);
float * temp = (float *)alloca(half * sizeof(float));
ggml_vec_gelu_backward_f32(half, temp, g_ptr, grad_ptr);
ggml_vec_mul_f32(half, dst_ptr + half, temp, x_ptr);
ggml_vec_gelu_backward_f32(nc, dst_ptr, g_ptr, grad_ptr);
}
}
}
@@ -3563,13 +3552,12 @@ void ggml_compute_forward_geglu_back(
ggml_tensor * dst) {

const struct ggml_tensor * grad = dst->src[0];
const struct ggml_tensor * x = dst->src[1];
const struct ggml_tensor * g = dst->src[2];
const struct ggml_tensor * g = dst->src[1];

switch (dst->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_back_f32(params, grad, x, g, dst);
ggml_compute_forward_geglu_back_f32(params, grad, g, dst);
} break;
default:
{
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -9592,7 +9592,7 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
}

static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(dst), (uint32_t)dst->ne[0], 0.0f, 0.0f }, dryrun);
}

static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
36 changes: 10 additions & 26 deletions ggml/src/ggml-vulkan/vulkan-shaders/geglu_back.comp
@@ -8,26 +8,19 @@
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer GRAD {A_TYPE data_grad[];};
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
layout (binding = 1) readonly buffer GATE {B_TYPE data_gate[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

float gelu(float x) {
float gelu_backward(float gate, float grad) {
const float c = 0.797884560802865; // sqrt(2/pi)
const float a = 0.044715;
const float inner = c * (x + a * x * x * x);
return 0.5 * x * (1.0 + tanh(inner));
}

float gelu_derivative(float x) {
const float c = 0.797884560802865; // sqrt(2/pi)
const float a = 0.044715;
const float x_squared = x * x;
const float x_cubed = x_squared * x;
const float inner = c * (x + a * x_cubed);
const float x_squared = gate * gate;
const float x_cubed = x_squared * gate;
const float inner = c * (gate + a * x_cubed);
const float tanh_val = tanh(inner);
const float sech2_val = 1.0 - tanh_val * tanh_val;
const float dtanh_dx = c * (1.0 + 3.0 * a * x_squared) * sech2_val;
return 0.5 * (1.0 + tanh_val + x * dtanh_dx);
return grad * 0.5 * (1.0 + tanh_val + gate * dtanh_dx);
}

void main() {
@@ -37,17 +30,8 @@ void main() {
return;
}

const uint half_size = p.KX / 2;

if (i < half_size) {
const float grad_val = float(data_grad[i]);
const float g_val = float(data_x[i + half_size]);
data_d[i] = D_TYPE(grad_val * gelu(g_val));
} else {
const uint idx = i - half_size;
const float grad_val = float(data_grad[idx]);
const float x_val = float(data_x[idx]);
const float g_val = float(data_x[i]);
data_d[i] = D_TYPE(grad_val * x_val * gelu_derivative(g_val));
}
const float grad_val = float(data_grad[i]);
const float gate_val = float(data_gate[i]);

data_d[i] = D_TYPE(gelu_backward(gate_val, grad_val));
}
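
For reference, the fused `gelu_backward` above (and `ggml_vec_gelu_backward_f32` on the CPU path) evaluates the derivative of the tanh-approximated GELU. With \( c = \sqrt{2/\pi} \) and \( a = 0.044715 \):

\[ \mathrm{gelu}(x) = \tfrac{1}{2}\, x \bigl(1 + \tanh(u)\bigr), \qquad u = c\,(x + a x^3) \]
\[ \mathrm{gelu}'(x) = \tfrac{1}{2}\bigl(1 + \tanh(u)\bigr) + \tfrac{1}{2}\, x \bigl(1 - \tanh^2(u)\bigr)\, c\,(1 + 3 a x^2) \]

so each output element is simply grad * gelu'(gate), which is why the old half-buffer indexing and the separate `gelu`/`gelu_derivative` helpers are no longer needed.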
12 changes: 5 additions & 7 deletions ggml/src/ggml.c
@@ -2696,14 +2696,12 @@ struct ggml_tensor * ggml_silu_back(
struct ggml_tensor * ggml_geglu_back(
struct ggml_context * ctx,
struct ggml_tensor * grad,
struct ggml_tensor * x,
struct ggml_tensor * g) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, x);
struct ggml_tensor * result = ggml_dup_tensor(ctx, g);

result->op = GGML_OP_GEGLU_BACK;
result->src[0] = grad;
result->src[1] = x;
result->src[2] = g;
result->src[1] = g;

return result;
}
@@ -6488,11 +6486,11 @@ static void ggml_compute_backward(
case GGML_GLU_OP_GEGLU: {
if (src0_needs_grads) {
GGML_ASSERT(src1 && "backward pass only implemented for split geglu");
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_gelu(ctx, src1)));
struct ggml_tensor * grad_mul_src1 = ggml_mul(ctx, grad, src1);
ggml_add_or_set(ctx, cgraph, isrc0, ggml_geglu_back(ctx, grad_mul_src1, src0));
}
if (src1_needs_grads) {
struct ggml_tensor * grad_mul_src0 = ggml_mul(ctx, grad, src0);
ggml_add_or_set(ctx, cgraph, isrc1, ggml_geglu_back(ctx, grad_mul_src0, src1, src1));
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_gelu(ctx, src0)));
}
} break;
default: {
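
The restructured backward pass above follows the chain rule for split GEGLU. Assuming the split forward convention \( y = \mathrm{gelu}(a) \odot b \) with a = src0 and b = src1 (which is what the updated calls imply):

\[ \frac{\partial L}{\partial a} = \frac{\partial L}{\partial y} \odot b \odot \mathrm{gelu}'(a), \qquad \frac{\partial L}{\partial b} = \frac{\partial L}{\partial y} \odot \mathrm{gelu}(a) \]

so `ggml_geglu_back(grad_mul_src1, src0)` contributes the \( \mathrm{gelu}'(a) \) factor for src0's gradient, while src1's gradient reuses the forward `ggml_gelu`.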
35 changes: 35 additions & 0 deletions tests/test-backend-ops.cpp
@@ -1951,6 +1951,39 @@ struct test_swiglu_oai : public test_case {
}
};

// GGML_OP_GEGLU_BACK
struct test_geglu_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;

std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}

test_geglu_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * gate = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(gate, "gate");

ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(grad, "grad");

ggml_tensor * out = ggml_geglu_back(ctx, grad, gate);
ggml_set_name(out, "out");

return out;
}

bool grad_precise() override {
return true;
}
};
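
As a quick host-side sanity check for this test, the expected output is elementwise grad * gelu'(gate). A minimal reference, assuming the same tanh-approximated GELU the backend kernels use (illustrative only, not part of the PR):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for GGML_OP_GEGLU_BACK: out[i] = grad[i] * gelu'(gate[i]).
static std::vector<float> geglu_back_ref(const std::vector<float> & grad,
                                         const std::vector<float> & gate) {
    const float c = 0.797884560802865f; // sqrt(2/pi)
    const float a = 0.044715f;
    std::vector<float> out(gate.size());
    for (size_t i = 0; i < gate.size(); ++i) {
        const float x = gate[i];
        const float u = c * (x + a * x * x * x);
        const float t = std::tanh(u);
        const float dtanh = c * (1.0f + 3.0f * a * x * x) * (1.0f - t * t);
        out[i] = grad[i] * 0.5f * (1.0f + t + x * dtanh);
    }
    return out;
}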

// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
@@ -5629,6 +5662,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}

test_cases.emplace_back(new test_geglu_back());

for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_get_rows(type, 300*256, 5, 4, 1, 2, false));
test_cases.emplace_back(new test_get_rows(type, 256, 80000, 70000, 2, 1, false));