1 change: 1 addition & 0 deletions common/common.cpp
@@ -1170,6 +1170,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.training = params.training;
cparams.kv_unified = params.kv_unified;

cparams.type_k = params.cache_type_k;
1 change: 1 addition & 0 deletions common/common.h
@@ -390,6 +390,7 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool training = false; // enable training mode (affects LoRA K/V gradient flow)
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)

bool single_turn = false; // single turn chat conversation
1 change: 1 addition & 0 deletions examples/training/finetune-lora.cpp
@@ -136,6 +136,7 @@ int main(int argc, char ** argv) {
common_init();
llama_backend_init();
llama_numa_init(params.numa);
params.training = true;

common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
1 change: 1 addition & 0 deletions include/llama.h
@@ -346,6 +346,7 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
bool training; // if true, we're in training mode (affects LoRA K/V gradient flow)
};

// model quantization parameters
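For reference, a minimal sketch (not part of this diff) of how a caller could opt into the new flag through the public C API. The helper name and the surrounding model-loading code are assumptions; only the `training` field comes from this change.

```cpp
#include "llama.h"

// Hypothetical helper: create a context with the new training flag enabled.
// llama_context_default_params() leaves training at false, so it only needs
// to be flipped on before the context is created.
static llama_context * make_training_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.training = true; // opt into LoRA K/V gradient flow during training
    return llama_init_from_model(model, cparams);
}
```

Callers going through `common_params` get the same effect by setting `params.training = true` before `common_init_from_params()`, which is what `finetune-lora.cpp` does above.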
2 changes: 2 additions & 0 deletions src/llama-context.cpp
@@ -43,6 +43,7 @@ llama_context::llama_context(
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.training = params.training;
cparams.warmup = false;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -2297,6 +2298,7 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
/*.training =*/ false,
};

return result;
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -33,6 +33,7 @@ struct llama_cparams {
bool warmup;
bool op_offload;
bool kv_unified;
bool training;

enum llama_pooling_type pooling_type;

22 changes: 18 additions & 4 deletions src/llama-graph.cpp
@@ -1524,8 +1524,15 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * k, * v;

if (loras && !loras->empty() && k_cur && v_cur && cparams.training) {
k = mctx_cur->get_k_lora(ctx0, k_cur, il);
v = mctx_cur->get_v_lora(ctx0, v_cur, il);
} else {
k = mctx_cur->get_k(ctx0, il);
v = mctx_cur->get_v(ctx0, il);
}

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
@@ -1591,8 +1598,15 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * k, * v;

if (loras && !loras->empty() && k_cur && v_cur && cparams.training) {
k = mctx_cur->get_k_lora(ctx0, k_cur, il);
v = mctx_cur->get_v_lora(ctx0, v_cur, il);
} else {
k = mctx_cur->get_k(ctx0, il);
v = mctx_cur->get_v(ctx0, il);
}

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
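The condition above routes K/V through the gradient-aware path only when LoRA adapters are loaded, fresh `k_cur`/`v_cur` tensors exist, and the context is in training mode; otherwise the usual cache views are used. A toy, simplified sketch of why this matters (assumed names, no LoRA scaling applied): the K projection below is built from the LoRA factors, so keeping it in the graph lets the backward pass reach them, whereas a tensor read back from the KV-cache buffer is a graph leaf with no path to the adapter weights.

```cpp
#include "ggml.h"

// Toy LoRA-adapted K projection (simplified; real code also applies a scale).
// Because the result depends on lora_a/lora_b, keeping it in the compute
// graph (as the training path does) lets gradients propagate to the LoRA tensors.
static ggml_tensor * k_with_lora(ggml_context * ctx,
                                 ggml_tensor * x,        // input activations
                                 ggml_tensor * wk,       // frozen base projection
                                 ggml_tensor * lora_a,   // trainable low-rank factor A
                                 ggml_tensor * lora_b) { // trainable low-rank factor B
    ggml_tensor * k_base = ggml_mul_mat(ctx, wk, x);
    ggml_tensor * k_lora = ggml_mul_mat(ctx, lora_b, ggml_mul_mat(ctx, lora_a, x));
    return ggml_add(ctx, k_base, k_lora);
}
```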
35 changes: 35 additions & 0 deletions src/llama-kv-cache.cpp
@@ -1104,6 +1104,33 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}

ggml_tensor * llama_kv_cache::get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
if (sinfo.s0 == 0) {
return k_cur;
}

slot_info past_sinfo = sinfo;
past_sinfo.s0 = 0;
past_sinfo.s1 = sinfo.s0 - 1;

ggml_tensor * k_past = get_k(ctx, il, n_kv, past_sinfo);

return ggml_concat(ctx, k_past, k_cur, 2);
}

ggml_tensor * llama_kv_cache::get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
if (sinfo.s0 == 0) {
return v_cur;
}

slot_info past_sinfo = sinfo;
past_sinfo.s0 = 0;
past_sinfo.s1 = sinfo.s0 - 1;
ggml_tensor * v_past = get_v(ctx, il, n_kv, past_sinfo);

return ggml_concat(ctx, v_past, v_cur, 2);
}

ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;

@@ -1978,6 +2005,14 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
return kv->get_k_lora(ctx, k_cur, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
return kv->get_v_lora(ctx, v_cur, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}
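`get_k_lora()`/`get_v_lora()` return `k_cur`/`v_cur` unchanged when `sinfo.s0 == 0`, and otherwise concatenate a view of the earlier cache slots with the current, gradient-carrying tensors along dim 2. A standalone sketch of that concat step with assumed shapes (the actual cache views returned by `get_k()`/`get_v()` are multi-dimensional views into the KV buffer):

```cpp
#include "ggml.h"

// Assumed layout: dim 0 = head size, dim 1 = number of KV heads,
// dim 2 = position. ggml_concat requires all non-concat dims to match,
// so only the position dimension grows.
static ggml_tensor * concat_past_and_current(ggml_context * ctx,
                                             ggml_tensor * past,  // e.g. [head_dim, n_head_kv, n_past]
                                             ggml_tensor * cur) { // e.g. [head_dim, n_head_kv, n_tokens]
    return ggml_concat(ctx, past, cur, 2);
}
```

Because the second operand is the freshly computed projection rather than a view into the static cache buffer, the concatenated result stays connected to the LoRA weights during backpropagation.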
8 changes: 8 additions & 0 deletions src/llama-kv-cache.h
@@ -149,6 +149,10 @@ class llama_kv_cache : public llama_memory_i {
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

// gradient-aware retrieval for LoRA training
ggml_tensor * get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

//
// preparation API
//
@@ -325,6 +329,10 @@ class llama_kv_cache_context : public llama_memory_context_i {
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

// gradient-aware retrieval for LoRA training
ggml_tensor * get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
// helps understand the implementation logic of cpy_k and cpy_v