#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

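// parse a comma-separated list of LoRA target modules into a bitmask;
// an empty string defaults to the four attention projections (Q, K, V, O)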
static uint32_t parse_lora_modules(const std::string & modules_str) {
    if (modules_str.empty()) {
        return LLAMA_LORA_TARGET_ATTN_Q | LLAMA_LORA_TARGET_ATTN_K | LLAMA_LORA_TARGET_ATTN_V | LLAMA_LORA_TARGET_ATTN_O;
    }

    static const std::map<std::string, uint32_t> module_map = {
        {"attn_q",   LLAMA_LORA_TARGET_ATTN_Q},
        {"attn_k",   LLAMA_LORA_TARGET_ATTN_K},
        {"attn_v",   LLAMA_LORA_TARGET_ATTN_V},
        {"attn_o",   LLAMA_LORA_TARGET_ATTN_O},
        {"ffn_gate", LLAMA_LORA_TARGET_FFN_GATE},
        {"ffn_up",   LLAMA_LORA_TARGET_FFN_UP},
        {"ffn_down", LLAMA_LORA_TARGET_FFN_DOWN},
        {"output",   LLAMA_LORA_TARGET_OUTPUT},
        {"all",      LLAMA_LORA_TARGET_ALL}
    };

    uint32_t target_modules = 0;
    std::stringstream ss(modules_str);
    std::string module;

    while (std::getline(ss, module, ',')) {
        // trim surrounding whitespace from the module name
        module.erase(0, module.find_first_not_of(" \t"));
        module.erase(module.find_last_not_of(" \t") + 1);

        auto it = module_map.find(module);
        if (it != module_map.end()) {
            target_modules |= it->second;
            LOG_INF("Added target module: %s\n", module.c_str());
        } else {
            LOG_ERR("Unknown LoRA target module: %s\n", module.c_str());
            LOG_ERR("Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n");
            return 0;
        }
    }

    return target_modules;
}

static void print_lora_usage() {
    printf("\nLoRA Fine-tuning Parameters:\n");
    printf("  --lora-rank N            LoRA rank (default: 8, range: 1-512)\n");
    printf("  --lora-alpha N           LoRA alpha scaling factor (default: 16.0, range: 0.1-1000.0)\n");
    printf("  --lora-modules MODULES   Target modules as comma-separated list (default: attn_q,attn_k,attn_v,attn_o)\n");
    printf("                           Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n");
    printf("                           Examples: \"attn_q,attn_v\" or \"all\" or \"attn_q,attn_k,attn_v,attn_o,ffn_gate,ffn_up,ffn_down\"\n");
    printf("  --output-adapter PATH    Output path for trained adapter (default: auto-generated)\n");
    printf("\nExamples:\n");
    printf("  # Train with rank=16, alpha=32, all attention modules\n");
    printf("  %s -m model.gguf -f dataset.txt --lora-rank 16 --lora-alpha 32 --lora-modules attn_q,attn_k,attn_v,attn_o\n", "finetune-lora");
    printf("\n  # Fine-tune existing adapter with all modules\n");
    printf("  %s -m model.gguf -f dataset.txt --lora existing.gguf --output-adapter improved.gguf\n", "finetune-lora");
    printf("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    int32_t     lora_rank  = 8;
    float       lora_alpha = 16.0f;
    std::string lora_modules_str;
    std::string output_adapter_path;

    params.escape = false;

    // drop argv[i] and argv[i + 1] by shifting the remaining arguments left
    auto remove_arg_pair = [&](int i) {
        for (int j = i; j < argc - 2; j++) {
            argv[j] = argv[j + 2];
        }
        argc -= 2;
    };

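    // strip the LoRA-specific flags (and their values) from argv before the
    // remaining arguments are handed to common_params_parse, which would
    // otherwise reject them as unknown options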
    for (int i = 1; i < argc - 1; i++) {
        if (strcmp(argv[i], "--lora-rank") == 0) {
            lora_rank = std::atoi(argv[i + 1]);
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--lora-alpha") == 0) {
            lora_alpha = std::atof(argv[i + 1]);
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--lora-modules") == 0) {
            lora_modules_str = argv[i + 1];
            remove_arg_pair(i);
            i--;
        } else if (strcmp(argv[i], "--output-adapter") == 0) {
            output_adapter_path = argv[i + 1];
            remove_arg_pair(i);
            i--;
        }
    }

    LOG_INF("Using LoRA parameters: rank=%d, alpha=%.1f\n", lora_rank, lora_alpha);

    // print the LoRA-specific usage in addition to the common help text
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
            print_lora_usage();
        }
    }

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        print_lora_usage();
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n", __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    common_init_result llama_init = common_init_from_params(params);
    llama_model_ptr   & model = llama_init.model;
    llama_context_ptr & ctx   = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    uint32_t target_modules = parse_lora_modules(lora_modules_str);
    if (target_modules == 0) {
        return 1;
    }

    struct llama_lora_training_params lora_params = {
        /*target_modules =*/ target_modules,
        /*rank           =*/ lora_rank,
        /*alpha          =*/ lora_alpha,
        /*dropout        =*/ 0.0f,
        /*init_std       =*/ 0.02f,
    };

    bool has_existing_lora = !params.lora_adapters.empty();
    struct llama_adapter_lora * trained_adapter = nullptr;

    if (has_existing_lora) {
        LOG_INF("Finetuning existing LoRA adapters\n");
        LOG_INF("Found %zu existing LoRA adapters to train\n", params.lora_adapters.size());
        trained_adapter = params.lora_adapters[0].ptr;
        if (!trained_adapter) {
            LOG_ERR("Existing LoRA adapter is null\n");
            return 1;
        }
    } else {
        LOG_INF("Target modules: Q=%s, K=%s, V=%s, O=%s, GATE=%s, UP=%s, DOWN=%s, OUTPUT=%s\n",
                (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_Q)   ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_K)   ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_V)   ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_O)   ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_GATE) ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP)   ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no",
                (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT)   ? "yes" : "no");

        LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n",
                lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank);

        trained_adapter = llama_lora_training_init(ctx.get(), model.get(), &lora_params);
        if (!trained_adapter) {
            LOG_ERR("%s: LoRA training initialization failed\n", __func__);
            return 1;
        }
    }

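    // fraction of the dataset batches held out for validation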
    constexpr float val_split = 0.05f;

    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
    optimizer_params.adamw.alpha = 1e-5f; // learning rate

    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_lora,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
        /*get_opt_pars_ud =*/ &optimizer_params,
        /*optimizer_type  =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

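    // the first (1 - val_split) fraction of the batches is used for training,
    // the remainder for validation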
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

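    // each epoch trains on the batches before idata_split and evaluates on the rest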
    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

    std::string adapter_filename;
    if (!output_adapter_path.empty()) {
        adapter_filename = output_adapter_path;
    } else if (has_existing_lora) {
        adapter_filename = "finetuned-lora-adapter.gguf";
        LOG_INF("Fine-tuned existing LoRA adapter, saving as: %s\n", adapter_filename.c_str());
    } else {
        adapter_filename = "trained-lora-adapter.gguf";
        LOG_INF("Saving new LoRA adapter: %s\n", adapter_filename.c_str());
    }

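    // write the trained adapter to a GGUF file and report its size on disk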
    if (trained_adapter) {
        if (llama_lora_save_adapter(trained_adapter, adapter_filename.c_str(), model.get())) {
            std::ifstream adapter_file(adapter_filename, std::ios::binary | std::ios::ate);
            if (adapter_file.is_open()) {
                std::streamsize adapter_size = adapter_file.tellg();
                LOG_INF("LoRA adapter saved: %s (%.2f MB)\n",
                        adapter_filename.c_str(), adapter_size / (1024.0 * 1024.0));
                adapter_file.close();
            }
        } else {
            LOG_ERR("Failed to save LoRA adapter\n");
        }
    } else {
        LOG_ERR("No trained adapter available for saving\n");
    }

    llama_backend_free();

    return 0;
}