Commit 7f3cae5

lora: Fix LoRA K/V gradient flow with gradient-connected KV cache retrieval
Add get_k_lora() and get_v_lora() methods that use concatenation instead of ggml_view_4d to maintain gradient connectivity during training. This ensures that the LoRA K/V parameters receive proper gradients while preserving causal attention behavior.
1 parent 5545231 commit 7f3cae5
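
Why the retrieval path matters here: the standard cache getters return a view into the persistent KV buffer, which is filled earlier through copy ops, so the backward pass never reaches the freshly computed, LoRA-carrying k_cur/v_cur. A minimal sketch of the two retrieval styles (tensor names, shapes, and strides below are placeholders, not taken from the code):

// Sketch only -- kv_buf, k_past, n_embd_head_k, n_head_kv, n_kv, nb1/nb2/nb3 are illustrative.

// (a) view-based retrieval (get_k/get_v): a window into the cache buffer.
//     The buffer was filled earlier via ggml_set_rows()/cpy_k(), so the autograd
//     graph has no edge from k_view back to k_cur -- the LoRA K/V weights get no gradient.
ggml_tensor * k_view = ggml_view_4d(ctx, kv_buf,
        n_embd_head_k, n_head_kv, n_kv, 1,
        nb1, nb2, nb3, 0);

// (b) concat-based retrieval (get_k_lora/get_v_lora, added in this commit):
//     the cached prefix is concatenated with the current-step tensor itself, so
//     the backward pass of ggml_concat routes the gradient of the trailing slice
//     into k_cur and from there into the LoRA adapter weights.
ggml_tensor * k_grad = ggml_concat(ctx, k_past, k_cur, 2);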

9 files changed: +68 -4 lines changed

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1170,6 +1170,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full = params.swa_full;
+    cparams.training = params.training;
     cparams.kv_unified = params.kv_unified;
 
     cparams.type_k = params.cache_type_k;

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -390,6 +390,7 @@ struct common_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
     bool no_op_offload = false; // globally disable offload host tensor operations to device
+    bool training = false; // enable training mode (affects LoRA K/V gradient flow)
     bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
 
     bool single_turn = false; // single turn chat conversation

examples/training/finetune-lora.cpp

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
+    params.training = true;
 
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -346,6 +346,7 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        bool training;   // if true, we're in training mode (affects LoRA K/V gradient flow)
     };
 
     // model quantization parameters
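
A hedged usage sketch (not part of this commit): a program that builds its context directly through the C API, rather than through common_init_from_params(), would enable the new flag before context creation. Only the training field comes from this commit; the surrounding calls are the usual llama.cpp context-creation entry points as recalled here and may differ by version.

// Sketch: enable training mode on a context created via the C API.
llama_context_params cparams = llama_context_default_params();
cparams.training = true;   // new in this commit: selects the gradient-connected
                           // get_k_lora()/get_v_lora() path when LoRA adapters are active

llama_context * ctx = llama_init_from_model(model, cparams);   // assumed entry point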

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ llama_context::llama_context(
     cparams.offload_kqv = params.offload_kqv;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
+    cparams.training = params.training;
     cparams.warmup = false;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -2297,6 +2298,7 @@ llama_context_params llama_context_default_params() {
        /*.op_offload =*/ true,
        /*.swa_full   =*/ true,
        /*.kv_unified =*/ false,
+       /*.training   =*/ false,
    };
 
    return result;

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool warmup;
     bool op_offload;
     bool kv_unified;
+    bool training;
 
     enum llama_pooling_type pooling_type;
 

src/llama-graph.cpp

Lines changed: 18 additions & 4 deletions
@@ -1524,8 +1524,15 @@ ggml_tensor * llm_graph_context::build_attn(
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+    ggml_tensor * k, * v;
+
+    if (loras && !loras->empty() && k_cur && v_cur && cparams.training) {
+        k = mctx_cur->get_k_lora(ctx0, k_cur, il);
+        v = mctx_cur->get_v_lora(ctx0, v_cur, il);
+    } else {
+        k = mctx_cur->get_k(ctx0, il);
+        v = mctx_cur->get_v(ctx0, il);
+    }
 
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
@@ -1591,8 +1598,15 @@
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+    ggml_tensor * k, * v;
+
+    if (loras && !loras->empty() && k_cur && v_cur && cparams.training) {
+        k = mctx_cur->get_k_lora(ctx0, k_cur, il);
+        v = mctx_cur->get_v_lora(ctx0, v_cur, il);
+    } else {
+        k = mctx_cur->get_k(ctx0, il);
+        v = mctx_cur->get_v(ctx0, il);
+    }
 
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);

src/llama-kv-cache.cpp

Lines changed: 35 additions & 0 deletions
@@ -1104,6 +1104,33 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
 
+ggml_tensor * llama_kv_cache::get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    if (sinfo.s0 == 0) {
+        return k_cur;
+    }
+
+    slot_info past_sinfo = sinfo;
+    past_sinfo.s0 = 0;
+    past_sinfo.s1 = sinfo.s0 - 1;
+
+    ggml_tensor * k_past = get_k(ctx, il, n_kv, past_sinfo);
+
+    return ggml_concat(ctx, k_past, k_cur, 2);
+}
+
+ggml_tensor * llama_kv_cache::get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    if (sinfo.s0 == 0) {
+        return v_cur;
+    }
+
+    slot_info past_sinfo = sinfo;
+    past_sinfo.s0 = 0;
+    past_sinfo.s1 = sinfo.s0 - 1;
+    ggml_tensor * v_past = get_v(ctx, il, n_kv, past_sinfo);
+
+    return ggml_concat(ctx, v_past, v_cur, 2);
+}
+
 ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -1978,6 +2005,14 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
+ggml_tensor * llama_kv_cache_context::get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    return kv->get_k_lora(ctx, k_cur, il, n_kv, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    return kv->get_v_lora(ctx, v_cur, il, n_kv, sinfos[i_cur]);
+}
+
 ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     return kv->build_input_k_idxs(ctx, ubatch);
 }
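
A concrete reading of get_k_lora() above, with made-up slot values:

// Illustration only -- the slot numbers are invented.
// Suppose the current ubatch occupies cache slots starting at sinfo.s0 == 32:
//
//   past_sinfo.s0 = 0;             // causal prefix starts at slot 0
//   past_sinfo.s1 = sinfo.s0 - 1;  // ... and ends at slot 31
//
//   k_past = get_k(ctx, il, n_kv, past_sinfo);    // view of the already-stored prefix
//   k      = ggml_concat(ctx, k_past, k_cur, 2);  // prefix ++ gradient-carrying current K
//
// For the first ubatch (sinfo.s0 == 0) there is no prefix, so k_cur is returned
// unchanged; attention over the concatenated tensor sees the same causal ordering
// as the view-based path.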

src/llama-kv-cache.h

Lines changed: 8 additions & 0 deletions
@@ -149,6 +149,10 @@ class llama_kv_cache : public llama_memory_i {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
+    // gradient-aware retrieval for LoRA training
+    ggml_tensor * get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+
     //
     // preparation API
     //
@@ -325,6 +329,10 @@ class llama_kv_cache_context : public llama_memory_context_i {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
+    // gradient-aware retrieval for LoRA training
+    ggml_tensor * get_k_lora(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * get_v_lora(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
     // create destination indices for each head of the current batch for where it would be written in the KV cache
     // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
     // helps understand the implementation logic of cpy_k and cpy_v
