
Commit ccbc006

Merge pull request #23 from zoq/set-rows-fix

Ignore GGML_OP_SET_ROWS parameters during gradient calculation

2 parents 788b32c + d92c7e5, commit ccbc006

25 files changed, +1558 -26 lines

examples/training/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -3,3 +3,9 @@ add_executable(${TARGET} finetune.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET llama-finetune-lora)
+add_executable(${TARGET} finetune-lora.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/training/README.md

Lines changed: 65 additions & 0 deletions

@@ -1,5 +1,6 @@
 # llama.cpp/examples/training
 
+## finetune
 This directory contains examples related to language model training using llama.cpp/GGML.
 So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
 Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
@@ -15,3 +16,67 @@ export model_name=llama_3.2-1b && export quantization=f32
 ```
 
 The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
+
+
+## finetune-lora
+
+LoRA (Low-Rank Adaptation) fine-tuning for efficient model training. This approach trains only a small set of additional parameters while keeping
+the base model frozen, making it memory-efficient.
+
+### Basic Usage
+
+```sh
+# Create a new LoRA adapter with default settings (rank=8, alpha=16, attention modules)
+./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512
+
+# Custom LoRA parameters (creates a new LoRA adapter and trains it from scratch)
+./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \
+    --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o"
+
+# Fine-tune an existing LoRA adapter
+./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \
+    --output-adapter improved_adapter.gguf -ngl 999 -c 512 -b 512 -ub 512
+```
+
+
+### Parameters
+
+#### LoRA Configuration
+- `--lora-rank N` - LoRA rank (default: 8)
+  - Lower rank = smaller adapter, less capacity
+  - Higher rank = larger adapter, more capacity
+- `--lora-alpha N` - LoRA alpha scaling factor (default: 16.0)
+  - Controls adaptation strength
+  - Common rule: alpha = 2 × rank
+- `--lora-modules MODULES` - Target modules as a comma-separated list
+  - Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all`
+  - Default: `attn_q,attn_k,attn_v,attn_o` (attention modules)
+- `--output-adapter PATH` - Output adapter filename (default: auto-generated)
+
+#### Standard Parameters
+- `-m MODEL` - Base model file (.gguf)
+- `-f FILE` - Training dataset
+- `-ngl N` - GPU layers (use 999 for full GPU training)
+- `-c N` - Context length (512 recommended for mobile)
+
+
+### Using Trained Adapters
+
+After training, you'll get a small adapter file. Use it with the original base model:
+
+```sh
+./build/bin/llama-cli -m base_model.gguf --lora trained_adapter.gguf -ngl 999
+```
+
+### Troubleshooting
+
+- **Out of memory**: Reduce the context length (`-c 256`), lower the rank, or use fewer target modules
+- **Poor quality**: Increase the rank, add more target modules, or train longer
+- **Large adapter**: Reduce the rank or limit the target modules
+
+### Help
+
+Run with `--help` or `-h` to see all available parameters:
+```sh
+./build/bin/llama-finetune-lora --help
+```
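As background for the `--lora-rank` and `--lora-alpha` options documented above: for each adapted weight matrix W (d_out × d_in), LoRA keeps W frozen and learns two small matrices B (d_out × r) and A (r × d_in), applying y = W·x + (alpha / r)·B·(A·x). The sketch below is a minimal standalone illustration of that arithmetic with hypothetical dimensions; it does not use the llama.cpp API or the adapter format produced by this PR.

```cpp
// Minimal LoRA forward sketch: y = W x + (alpha / rank) * B (A x).
// Hypothetical sizes; illustrates why an adapter stores only
// rank * (d_in + d_out) extra values per adapted matrix.
#include <cstdio>
#include <vector>

int main() {
    const int d_in = 2048, d_out = 2048, rank = 8;
    const float alpha = 16.0f, scale = alpha / rank;

    std::vector<float> W(d_out * d_in, 0.01f);   // frozen base weight
    std::vector<float> A(rank  * d_in, 0.001f);  // trainable, rank x d_in
    std::vector<float> B(d_out * rank, 0.0f);    // trainable, d_out x rank (zero-init => no initial change)
    std::vector<float> x(d_in, 1.0f), Ax(rank, 0.0f), y(d_out, 0.0f);

    for (int r = 0; r < rank; ++r)
        for (int i = 0; i < d_in; ++i)
            Ax[r] += A[r * d_in + i] * x[i];

    for (int o = 0; o < d_out; ++o) {
        for (int i = 0; i < d_in; ++i)
            y[o] += W[o * d_in + i] * x[i];          // frozen path
        for (int r = 0; r < rank; ++r)
            y[o] += scale * B[o * rank + r] * Ax[r]; // low-rank update
    }

    printf("adapter params per matrix: %lld (vs %lld in W)\n",
           (long long) rank * (d_in + d_out), (long long) d_out * d_in);
    printf("y[0] = %f\n", y[0]);
    return 0;
}
```

With these example numbers, rank 8 on a 2048×2048 projection adds 32,768 adapter values next to roughly 4.2M frozen weights, which is why raising the rank grows the adapter (and its capacity) linearly.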
examples/training/finetune-lora.cpp

Lines changed: 263 additions & 0 deletions

@@ -0,0 +1,263 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <map>     // std::map used for the module-name lookup table
#include <sstream> // std::stringstream used to split the module list
#include <string>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

// Parse a comma-separated list of module names into a bitmask of
// LLAMA_LORA_TARGET_* flags; an empty list selects the attention projections.
static uint32_t parse_lora_modules(const std::string & modules_str) {
    if (modules_str.empty()) {
        return LLAMA_LORA_TARGET_ATTN_Q | LLAMA_LORA_TARGET_ATTN_K | LLAMA_LORA_TARGET_ATTN_V | LLAMA_LORA_TARGET_ATTN_O;
    }

    static const std::map<std::string, uint32_t> module_map = {
        {"attn_q",   LLAMA_LORA_TARGET_ATTN_Q},
        {"attn_k",   LLAMA_LORA_TARGET_ATTN_K},
        {"attn_v",   LLAMA_LORA_TARGET_ATTN_V},
        {"attn_o",   LLAMA_LORA_TARGET_ATTN_O},
        {"ffn_gate", LLAMA_LORA_TARGET_FFN_GATE},
        {"ffn_up",   LLAMA_LORA_TARGET_FFN_UP},
        {"ffn_down", LLAMA_LORA_TARGET_FFN_DOWN},
        {"output",   LLAMA_LORA_TARGET_OUTPUT},
        {"all",      LLAMA_LORA_TARGET_ALL}
    };

    uint32_t target_modules = 0;
    std::stringstream ss(modules_str);
    std::string module;

    while (std::getline(ss, module, ',')) {
        // trim surrounding whitespace
        module.erase(0, module.find_first_not_of(" \t"));
        module.erase(module.find_last_not_of(" \t") + 1);

        auto it = module_map.find(module);
        if (it != module_map.end()) {
            target_modules |= it->second;
            LOG_INF("Added target module: %s\n", module.c_str());
        } else {
            LOG_ERR("Unknown LoRA target module: %s\n", module.c_str());
            LOG_ERR("Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n");
            return 0;
        }
    }

    return target_modules;
}

static void print_lora_usage() {
    printf("\nLoRA Fine-tuning Parameters:\n");
    printf("  --lora-rank N            LoRA rank (default: 8, range: 1-512)\n");
    printf("  --lora-alpha N           LoRA alpha scaling factor (default: 16.0, range: 0.1-1000.0)\n");
    printf("  --lora-modules MODULES   Target modules as comma-separated list (default: attn_q,attn_k,attn_v,attn_o)\n");
    printf("                           Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n");
    printf("                           Examples: \"attn_q,attn_v\" or \"all\" or \"attn_q,attn_k,attn_v,attn_o,ffn_gate,ffn_up,ffn_down\"\n");
    printf("  --output-adapter PATH    Output path for trained adapter (default: auto-generated)\n");
    printf("\nExamples:\n");
    printf("  # Train with rank=16, alpha=32, all attention modules\n");
    printf("  %s -m model.gguf -f dataset.txt --lora-rank 16 --lora-alpha 32 --lora-modules attn_q,attn_k,attn_v,attn_o\n", "finetune-lora");
    printf("\n  # Fine-tune existing adapter with all modules\n");
    printf("  %s -m model.gguf -f dataset.txt --lora existing.gguf --output-adapter improved.gguf\n", "finetune-lora");
    printf("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    int32_t     lora_rank  = 8;
    float       lora_alpha = 16.0f;
    std::string lora_modules_str;
    std::string output_adapter_path;

    params.escape = false;

    // remove a "--flag value" pair from argv in place
    auto remove_arg_pair = [&](int i) {
        for (int j = i; j < argc - 2; j++) {
            argv[j] = argv[j + 2];
        }
        argc -= 2;
    };

    // consume the LoRA-specific flags before the remaining arguments are
    // handed to common_params_parse()
    for (int i = 1; i < argc - 1; i++) {
        if (strcmp(argv[i], "--lora-rank") == 0) {
            lora_rank = std::atoi(argv[i + 1]);
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--lora-alpha") == 0) {
            lora_alpha = std::atof(argv[i + 1]);
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--lora-modules") == 0) {
            lora_modules_str = argv[i + 1];
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--output-adapter") == 0) {
            output_adapter_path = argv[i + 1];
            remove_arg_pair(i);
            i--;
        }
    }

    LOG_INF("Using LoRA parameters: rank=%d, alpha=%.1f\n", lora_rank, lora_alpha);

    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
            print_lora_usage();
        }
    }

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        print_lora_usage();
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n", __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    common_init_result llama_init = common_init_from_params(params);
    llama_model_ptr   & model = llama_init.model;
    llama_context_ptr & ctx   = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    uint32_t target_modules = parse_lora_modules(lora_modules_str);
    if (target_modules == 0) {
        return 1;
    }

    struct llama_lora_training_params lora_params = {
        /*target_modules =*/ target_modules,
        /*rank           =*/ lora_rank,
        /*alpha          =*/ lora_alpha,
        /*dropout        =*/ 0.0f,
        /*init_std       =*/ 0.02f,
    };

    bool has_existing_lora = !params.lora_adapters.empty();
    struct llama_adapter_lora * trained_adapter = nullptr;

    if (has_existing_lora) {
        LOG_INF("Finetuning existing LoRA adapters\n");
        LOG_INF("Found %zu existing LoRA adapters to train\n", params.lora_adapters.size());
        trained_adapter = params.lora_adapters[0].ptr;
        if (!trained_adapter) {
            LOG_ERR("Existing LoRA adapter is null\n");
            return 1;
        }
    } else {
        LOG_INF("Target modules: Q=%s, K=%s, V=%s, O=%s, GATE=%s, UP=%s, DOWN=%s, OUTPUT=%s\n",
            (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_Q)   ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_K)   ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_V)   ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_O)   ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_GATE) ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP)   ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no",
            (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT)   ? "yes" : "no");

        LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n",
            lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank);

        trained_adapter = llama_lora_training_init(ctx.get(), model.get(), &lora_params);
        if (!trained_adapter) {
            LOG_ERR("%s: LoRA training initialization failed\n", __func__);
            return 1;
        }
    }

    constexpr float val_split = 0.05f;

    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
    optimizer_params.adamw.alpha = 1e-5f; // learning rate

    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_lora,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
        /*get_opt_pars_ud =*/ &optimizer_params,
        /*optimizer_type  =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

    // the tail of the dataset (val_split) is held out for evaluation
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

    std::string adapter_filename;
    if (!output_adapter_path.empty()) {
        adapter_filename = output_adapter_path;
    } else if (has_existing_lora) {
        adapter_filename = "finetuned-lora-adapter.gguf";
        LOG_INF("Finetuned existing lora adapter, saving as: %s\n", adapter_filename.c_str());
    } else {
        adapter_filename = "trained-lora-adapter.gguf";
        LOG_INF("Saving new lora adapter: %s\n", adapter_filename.c_str());
    }

    if (trained_adapter) {
        if (llama_lora_save_adapter(trained_adapter, adapter_filename.c_str(), model.get())) {
            std::ifstream adapter_file(adapter_filename, std::ios::binary | std::ios::ate);
            if (adapter_file.is_open()) {
                std::streamsize adapter_size = adapter_file.tellg();
                LOG_INF("LoRA adapter saved: %s (%.2f MB)\n",
                    adapter_filename.c_str(), adapter_size / (1024.0 * 1024.0));
                adapter_file.close();
            }
        } else {
            LOG_ERR("Failed to save LoRA adapter\n");
        }
    } else {
        LOG_ERR("No trained adapter available for saving\n");
    }

    llama_backend_free();

    return 0;
}
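One pattern in the file above worth calling out: the LoRA-specific flags are consumed from `argv` in place before the remaining arguments reach `common_params_parse()`, which would otherwise reject the unknown options. The following is a stripped-down standalone sketch of that pattern; the flag name `--my-flag` is hypothetical and there are no llama.cpp dependencies.

```cpp
// Consume "--my-flag VALUE" pairs from argv in place, so a downstream
// parser that does not know the flag never sees it.
#include <cstdio>
#include <cstdlib>
#include <cstring>

int main(int argc, char ** argv) {
    int my_value = 8; // default

    for (int i = 1; i < argc - 1; i++) {
        if (strcmp(argv[i], "--my-flag") == 0) {
            my_value = std::atoi(argv[i + 1]);
            // shift the remaining arguments left over the flag/value pair
            for (int j = i; j < argc - 2; j++) {
                argv[j] = argv[j + 2];
            }
            argc -= 2;
            i--; // re-examine the slot that now holds the next argument
        }
    }

    printf("my_value = %d, remaining argc = %d\n", my_value, argc);
    // ...argc/argv would now be handed to the generic parser...
    return 0;
}
```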

examples/training/finetune.cpp

Lines changed: 4 additions & 3 deletions

@@ -54,8 +54,8 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
+    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct lr_opt & lr = params.lr;
     LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
@@ -64,7 +64,8 @@ int main(int argc, char ** argv) {
 
     struct llama_opt_params lopt_params{
        /*n_ctx_train     =*/0,
-       /*param_filter    =*/llama_opt_param_filter_all,
+       // /*param_filter =*/llama_opt_param_filter_all,
+       llama_opt_param_filter_lora,
        /*param_filter_ud =*/nullptr,
        /*get_opt_pars    =*/common_opt_lr_pars,
        /*get_opt_pars_ud =*/&params.lr,
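The functional change here is the parameter filter: `llama_opt_param_filter_all` is replaced by the new `llama_opt_param_filter_lora`, so only the LoRA adapter tensors receive gradients while the base weights stay frozen. As an illustration only (not this PR's actual implementation), a name-based predicate matching the `bool(const struct ggml_tensor *, void *)` shape used by `llama_opt_param_filter_all` could look like the sketch below; the `lora_a`/`lora_b` name suffixes and the function name are assumptions.

```cpp
#include <cstring>

#include "ggml.h"

// Hypothetical parameter filter: mark a tensor as trainable only if its name
// suggests it belongs to a LoRA adapter. It would be passed via the
// param_filter field of llama_opt_params.
static bool example_param_filter_lora(const struct ggml_tensor * tensor, void * /*userdata*/) {
    const char * name = ggml_get_name(tensor);
    return strstr(name, "lora_a") != nullptr || strstr(name, "lora_b") != nullptr;
}
```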

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions

@@ -479,6 +479,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
+        GGML_OP_GEGLU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
@@ -1130,6 +1131,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * g);
+
     // hardswish(x) = x * relu6(x + 3) / 6
     GGML_API struct ggml_tensor * ggml_hardswish(
             struct ggml_context * ctx,
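For context on the new `GGML_OP_GEGLU_BACK` op and the `ggml_geglu_back(ctx, grad, x, g)` declaration: GEGLU gates one input with a GELU of the other, and backpropagating through GEGLU feed-forward blocks requires a dedicated backward op. Assuming the forward convention $\mathrm{out} = x \odot \mathrm{GELU}(g)$ (the exact argument convention is not confirmed by this diff) and the exact, erf-based GELU, the gradients such an op has to produce are

$$
\frac{\partial L}{\partial x} = \mathrm{grad} \odot \mathrm{GELU}(g),
\qquad
\frac{\partial L}{\partial g} = \mathrm{grad} \odot x \odot \mathrm{GELU}'(g),
\qquad
\mathrm{GELU}'(g) = \Phi(g) + g\,\phi(g),
$$

where $\Phi$ and $\phi$ are the standard normal CDF and PDF.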
