From ec4e2b8d82f22401c862b6a1a8e99ee6aca64245 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Fri, 24 Oct 2025 11:59:57 +0200 Subject: [PATCH 1/5] lora: Add Instruction Finetuning support - Add masked loss computation on assistant responses only - Implement Vulkan masked cross-entropy loss shader & count_equal shader - Support default ChatML template & custom jinja chat templates --- common/common.cpp | 289 ++++++++++++++++++ common/common.h | 5 + examples/training/README.md | 15 +- examples/training/finetune-lora.cpp | 55 +++- examples/training/finetune.cpp | 15 +- ggml/include/ggml-opt.h | 20 ++ ggml/include/ggml.h | 23 ++ ggml/src/ggml-cpu/ggml-cpu.c | 23 ++ ggml/src/ggml-cpu/ops.cpp | 284 +++++++++++++++++ ggml/src/ggml-cpu/ops.h | 3 + ggml/src/ggml-opt.cpp | 162 +++++++++- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 158 +++++++++- .../vulkan-shaders/count_equal_masked.comp | 46 +++ .../cross_entropy_loss_masked_back.comp | 115 +++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + ggml/src/ggml.c | 82 ++++- include/llama.h | 2 + src/llama-context.cpp | 58 +++- src/llama-context.h | 2 + 19 files changed, 1325 insertions(+), 34 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/count_equal_masked.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_masked_back.comp diff --git a/common/common.cpp b/common/common.cpp index 2cf15697701..2763d97a076 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,8 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "chat.h" +#include #include #include @@ -1616,3 +1618,290 @@ float lr_opt::get_lr(float epoch) const { LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); return r; } + +ggml_opt_dataset_t common_opt_sft_dataset_init( + struct llama_context * ctx, + const std::string & json_content, + int64_t stride, + const std::string & chat_template_path) { + using json = nlohmann::json; + + const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx)); + common_chat_templates_ptr chat_templates; + std::string chat_template_source; + if (!chat_template_path.empty()) { + std::ifstream tmpl_file(chat_template_path); + if (!tmpl_file.is_open()) { + LOG_ERR("Warning: Failed to open chat template file: %s\n", chat_template_path.c_str()); + } else { + chat_template_source.assign(std::istreambuf_iterator(tmpl_file), std::istreambuf_iterator()); + tmpl_file.close(); + try { + chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source); + } catch (const std::exception & e) { + LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what()); + } + } + } + + std::vector conversations; + std::istringstream content_stream(json_content); + + std::string line; + while (std::getline(content_stream, line)) { + if (line.empty() || line[0] == '#') continue; + try { + json conv = json::parse(line); + if (conv.contains("messages") && conv["messages"].is_array()) { + conversations.push_back(conv); + } + } catch (const json::exception & e) { + LOG_DBG("Warning: Failed to parse JSON line: %s\n", e.what()); + } + } + + if (conversations.empty()) { + LOG_ERR("Error: No valid conversations found\n"); + return nullptr; + } + LOG_INF("Loaded %zu conversations\n", conversations.size()); + + const int64_t ne_datapoint = llama_n_ctx(ctx); + if (stride <= 0) stride = ne_datapoint; + if (stride > ne_datapoint) stride = ne_datapoint; + + std::vector> all_tokenized_data; + std::vector> all_assistant_masks; + + auto token_count_prefix = [&](const std::string & 
render, size_t char_count) -> size_t { + std::string prefix = render.substr(0, char_count); + auto t = common_tokenize(ctx, prefix, /*add_special=*/false, /*parse_special=*/true); + return t.size(); + }; + + const std::string START_TAG = "<|im_start|>"; + const std::string START_SYS = "<|im_start|>system\n"; + const std::string START_USR = "<|im_start|>user\n"; + const std::string START_AST = "<|im_start|>assistant\n"; + const std::string END_TAG = "<|im_end|>"; + const std::string NL = "\n"; + + for (size_t i = 0; i < conversations.size(); ++i) { + const auto & messages = conversations[i]["messages"]; + if (!messages.is_array() || messages.empty()) continue; + + std::string render; + + if (chat_templates) { + std::vector chat_msgs; + chat_msgs.reserve(messages.size()); + for (const auto & msg : messages) { + if (!msg.contains("role") || !msg.contains("content")) { + continue; + } + common_chat_msg chat_msg; + chat_msg.role = msg["role"].get(); + chat_msg.content = msg["content"].get(); + chat_msgs.push_back(std::move(chat_msg)); + } + + if (!chat_msgs.empty()) { + common_chat_templates_inputs inputs; + inputs.messages = std::move(chat_msgs); + inputs.add_generation_prompt = false; + inputs.use_jinja = true; + try { + render = common_chat_templates_apply(chat_templates.get(), inputs).prompt; + + size_t last_im_end = render.rfind("<|im_end|>"); + if (last_im_end != std::string::npos) { + size_t end_pos = last_im_end + 10; // length of "<|im_end|>" + // Remove any trailing whitespace/newlines after the final <|im_end|> + while (end_pos < render.size() && (render[end_pos] == '\n' || render[end_pos] == '\r' || render[end_pos] == ' ')) { + end_pos++; + } + if (end_pos < render.size()) { + render = render.substr(0, last_im_end + 10); // Keep only up to + } + } + } catch (const std::exception & e) { + LOG_WRN("Warning: chat template rendering failed for conversation %zu: %s. 
Falling back to default ChatML rendering.\n", + i, e.what()); + } + } + } + + if (render.empty()) { + render.reserve(4096); + for (const auto & msg : messages) { + if (!msg.contains("role") || !msg.contains("content")) continue; + const std::string role = msg["role"].get(); + const std::string content = msg["content"].get(); + + if (role == "system") { + render += START_SYS; render += content; render += END_TAG + NL; + } else if (role == "user") { + render += START_USR; render += content; render += END_TAG + NL; + } else if (role == "assistant") { + render += START_AST; render += content; render += END_TAG + NL; + } + } + } + + if (render.empty()) { + continue; + } + + struct Span { size_t lo, hi; }; + std::vector assistant_spans; + + { + size_t from = 0; + while (true) { + size_t open = render.find(START_AST, from); + if (open == std::string::npos) break; + + // Include the role token ("assistant") and everything through the closing tag/newlines + size_t lo = open + START_TAG.size(); + if (lo > render.size()) { + lo = render.size(); + } + + size_t close = render.find(END_TAG, open + START_AST.size()); + if (close == std::string::npos) { + assistant_spans.push_back({lo, render.size()}); + break; + } + + size_t hi = close + END_TAG.size(); + if (hi <= lo) { + lo = open; + hi = close + END_TAG.size(); + } + + assistant_spans.push_back({lo, std::min(hi, render.size())}); + + size_t next_from = hi; + from = next_from; + } + } + + if (assistant_spans.empty()) { + LOG_WRN("Conversation %zu has no assistant spans\n", i); + continue; + } + + auto tokens_full = common_tokenize(ctx, render, /*add_special=*/false, /*parse_special=*/true); + if (tokens_full.empty()) continue; + + std::vector assistant_mask(tokens_full.size(), 0); + size_t assistant_token_count = 0; + + for (const auto & sp : assistant_spans) { + size_t t_lo = token_count_prefix(render, sp.lo); + size_t t_hi = token_count_prefix(render, sp.hi); + if (t_lo > tokens_full.size()) t_lo = tokens_full.size(); + if (t_hi > tokens_full.size()) t_hi = tokens_full.size(); + + + for (size_t t = t_lo; t < t_hi; ++t) { + assistant_mask[t] = 1; + ++assistant_token_count; + } + } + + if (assistant_token_count == 0) { + LOG_WRN("Warning: Conversation %zu has zero assistant tokens after masking\n", i); + continue; + } + + all_tokenized_data.push_back(tokens_full); + all_assistant_masks.push_back(assistant_mask); + } + + if (all_tokenized_data.empty()) { + LOG_ERR("ERROR: No valid training samples generated after processing %zu conversations\n", conversations.size()); + return nullptr; + } + + std::vector> final_samples; + std::vector> final_masks; + + llama_token pad_token = llama_vocab_pad(vocab); + if (pad_token == LLAMA_TOKEN_NULL) { + pad_token = llama_vocab_eos(vocab); + } + + for (size_t i = 0; i < all_tokenized_data.size(); ++i) { + const auto& conv_tokens = all_tokenized_data[i]; + const auto& conv_mask = all_assistant_masks[i]; + + if ((int64_t)conv_tokens.size() > ne_datapoint) { + LOG_WRN("Skipping conversation %zu: too long (%zu tokens > %lld)\n", i, conv_tokens.size(), (long long)ne_datapoint); + continue; + } + + size_t conv_assistant_tokens = 0; + for (int32_t mask_val : conv_mask) { + if (mask_val == 1) conv_assistant_tokens++; + } + + if (conv_assistant_tokens == 0) { + LOG_WRN("Skipping conversation %zu: no assistant tokens\n", i); + continue; + } + + std::vector sample_tokens = conv_tokens; + std::vector sample_mask = conv_mask; + + sample_tokens.resize(ne_datapoint, pad_token); + sample_mask.resize(ne_datapoint, 0); // Padding 
tokens are not trained on + + final_samples.push_back(sample_tokens); + final_masks.push_back(sample_mask); + } + + all_tokenized_data = std::move(final_samples); + all_assistant_masks = std::move(final_masks); + + const int64_t ndata = all_tokenized_data.size(); + + ggml_opt_dataset_t result = ggml_opt_dataset_init_with_masks( + GGML_TYPE_I32, GGML_TYPE_I32, GGML_TYPE_I32, + /*ne_datapoint=*/ne_datapoint, /*ne_label=*/ne_datapoint, /*ne_mask=*/ne_datapoint, + /*ndata=*/ndata, /*ndata_shard=*/1); + + if (result == nullptr) { + return nullptr; + } + + int32_t * data = (int32_t *) ggml_opt_dataset_data(result)->data; + int32_t * labels = (int32_t *) ggml_opt_dataset_labels(result)->data; + int32_t * masks = (int32_t *) ggml_opt_dataset_masks(result)->data; + + for (int64_t idata = 0; idata < ndata; ++idata) { + const auto & sample_tokens = all_tokenized_data[idata]; + const auto & sample_mask = all_assistant_masks[idata]; + + // inputs + for (int64_t i = 0; i < ne_datapoint; ++i) { + data[idata * ne_datapoint + i] = sample_tokens[i]; + } + + // labels: Set actual next tokens for ALL positions (masked cross-entropy needs real tokens) + for (int64_t i = 0; i < ne_datapoint - 1; ++i) { + // Always set the actual next token - masking is handled separately + labels[idata * ne_datapoint + i] = sample_tokens[i + 1]; + } + labels[idata * ne_datapoint + (ne_datapoint - 1)] = sample_tokens[ne_datapoint - 1]; // last token predicts itself (will be masked) + + // masks: indicate which preds should be trained on (shifted by 1 from sample_mask) + // Since we predict token[i+1] from token[i], we train when token[i+1] is assistant + for (int64_t i = 0; i < ne_datapoint - 1; ++i) { + masks[idata * ne_datapoint + i] = (i + 1 < ne_datapoint && sample_mask[i + 1] == 1) ? 1 : 0; + } + masks[idata * ne_datapoint + (ne_datapoint - 1)] = 0; + } + + return result; +} diff --git a/common/common.h b/common/common.h index 8043c68c7c3..d0fc268ca6f 100644 --- a/common/common.h +++ b/common/common.h @@ -745,3 +745,8 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std // "adamw" or "sgd" (case insensitive) enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); +ggml_opt_dataset_t common_opt_sft_dataset_init( + struct llama_context * ctx, + const std::string & json_content, + int64_t stride, + const std::string & chat_template_path = ""); diff --git a/examples/training/README.md b/examples/training/README.md index 7d1cda2a9ca..c2ad915479d 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -27,11 +27,11 @@ the base model frozen, making it memory-efficient. ```sh # Create new LoRA adapter with default settings (rank=8, alpha=16, attention modules) -./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 +./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 -fa off # Custom LoRA parameters(creates new lora adapter and trains it from scratch) ./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \ - --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" + --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" -fa off # Fine-tune existing LoRA adapter ./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \ @@ -44,8 +44,17 @@ the base model frozen, making it memory-efficient. 
# Resume training from checkpoint ./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \ --resume-from "./lora_checkpoints/checkpoint_step_00000150/" + --output-adapter improved_adapter.gguf -ngl 999 -c 512 -b 512 -ub 512 -fa off + +# Supervised FineTuning with Assistant only loss +./build/bin/llama-finetune-lora -m model.gguf -f dataset.jsonl -ngl 999 -c 512 -b 512 -ub 512 \ + --lora-modules "attn_q,attn_k,attn_v,attn_o" --assistant-loss-only -fa off ``` +### SFT(Instruction Fine Tuning) with Assistant Only Loss +- Masks the system and user tokens and only computes loss on assistant tokens +- Requires the dataset to be in json format just like huggingface with `role` and `content` for each role +- Allows users to optionally pass a jinja chat template with `--chat-template chat-ml-template.jinja` ### Parameters @@ -60,6 +69,8 @@ the base model frozen, making it memory-efficient. - Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all` - Default: `attn_q,attn_k,attn_v,attn_o` (attention modules) - `--output-adapter PATH` - Output adapter filename (default: auto-generated) +- `--assistant-loss-only` - Trains only on assistant tokens +- `--chat-template` - Jinja chat template for chat ML formatting to train on assistant tokens only #### Checkpointing - `--checkpoint-save-steps N` - Save checkpoint every N training steps (default: 100) diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp index fefe858ef82..dc920f36ab1 100644 --- a/examples/training/finetune-lora.cpp +++ b/examples/training/finetune-lora.cpp @@ -132,6 +132,9 @@ static void print_lora_usage() { printf(" --output-adapter PATH Output path for trained adapter (default: auto-generated)\n"); printf("\nTraining Options:\n"); printf(" --num-epochs N Number of training epochs (default: 1)\n"); + printf(" --assistant-loss-only Use JSON dataset format with masked loss (ChatML/conversation format)\n"); + printf(" Only computes loss on assistant responses, not system/user prompts\n"); + printf(" --chat-template PATH Optional Jinja chat template to render JSON dataset (matches HF apply_chat_template)\n"); printf("\nCheckpointing Options:\n"); printf(" --checkpoint-save-steps N Save checkpoint every N training steps (default: 100)\n"); printf(" --checkpoint-save-dir PATH Directory for checkpoints (default: ./checkpoints)\n"); @@ -142,6 +145,8 @@ static void print_lora_usage() { printf(" %s -m model.gguf -f dataset.txt --lora-rank 16 --lora-alpha 32 --lora-modules attn_q,attn_k,attn_v,attn_o\n", "finetune-lora"); printf("\n # Fine-tune existing adapter with all modules\n"); printf(" %s -m model.gguf -f dataset.txt --lora existing.gguf --output-adapter improved.gguf\n", "finetune-lora"); + printf("\n # Instruction fine-tuning with ChatML format\n"); + printf(" %s -m model.gguf -f conversations.jsonl --assistant-loss-only --lora-rank 16\n", "finetune-lora"); printf("\n"); } @@ -343,6 +348,8 @@ struct finetune_params { std::string checkpoint_save_dir = "./checkpoints"; std::string resume_from_checkpoint; bool auto_resume = false; + std::string chat_template_path; + bool assistant_loss_only = false; }; static bool parse_finetune_args(int& argc, char** argv, finetune_params& ft_params) { @@ -353,6 +360,21 @@ static bool parse_finetune_args(int& argc, char** argv, finetune_params& ft_para argc -= 2; }; + auto remove_arg_single = [&](int i) { + for (int j = i; j < argc - 1; j++) { + argv[j] = argv[j + 1]; + } + argc -= 1; + 
}; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--assistant-loss-only") == 0) { + ft_params.assistant_loss_only = true; + remove_arg_single(i); + i--; + } + } + for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "--lora-rank") == 0 && i + 1 < argc) { ft_params.lora_rank = std::atoi(argv[i + 1]); @@ -393,6 +415,10 @@ static bool parse_finetune_args(int& argc, char** argv, finetune_params& ft_para } argc--; i--; + } else if (strcmp(argv[i], "--chat-template") == 0) { + ft_params.chat_template_path = argv[i + 1]; + remove_arg_pair(i); + i--; } } @@ -428,6 +454,10 @@ int main(int argc, char ** argv) { } } + if (!ft_params.resume_from_checkpoint.empty()) { + params.warmup = false; + } + // Load checkpoint LoRA adapter from directory structure (model.gguf) if (!ft_params.resume_from_checkpoint.empty()) { std::filesystem::path checkpoint_dir(ft_params.resume_from_checkpoint); @@ -512,9 +542,9 @@ int main(int argc, char ** argv) { (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP) ? "yes" : "no", (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no", (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT) ? "yes" : "no"); - + LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n", - lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank); + lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank); trained_adapter = llama_lora_training_init(ctx.get(), model.get(), &lora_params); if (!trained_adapter) { @@ -525,8 +555,21 @@ int main(int argc, char ** argv) { constexpr float val_split = 0.05f; - std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); - ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + ggml_opt_dataset_t dataset; + + if (ft_params.assistant_loss_only) { + LOG_INF("Using JSON dataset with chat template and assistant-only loss\n"); + dataset = common_opt_sft_dataset_init(ctx.get(), params.prompt, llama_n_ctx(ctx.get())/2, ft_params.chat_template_path); + } else { + std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); + LOG_INF("Using standard next-token prediction mode\n"); + dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + } + + if (dataset == nullptr) { + LOG_ERR("Failed to create dataset. Please check your input file and parameters.\n"); + return 1; + } int start_epoch = 0; int64_t start_step = 0; @@ -557,9 +600,10 @@ int main(int argc, char ** argv) { LOG_ERR("Failed to load checkpoint, starting from scratch\n"); } } - + struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); optimizer_params.adamw.alpha = 1e-5f; // learning rate + optimizer_params.adamw.wd = 0.01f; std::string optimizer_checkpoint_path; if (checkpoint_loaded && !ft_params.resume_from_checkpoint.empty()) { @@ -576,6 +620,7 @@ int main(int argc, char ** argv) { /*optimizer_type =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW, /*checkpoint_path =*/ checkpoint_loaded ? 
optimizer_checkpoint_path.c_str() : nullptr, /*load_optimizer_state =*/ checkpoint_loaded, + /*assistant_loss_only =*/ ft_params.assistant_loss_only, }; llama_opt_init(ctx.get(), model.get(), lopt_params); diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index 82a6b75e6fb..f8675a9fab2 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -133,14 +133,15 @@ int main(int argc, char ** argv) { (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split); struct llama_opt_params lopt_params { - /*n_ctx_train =*/ 0, - /*param_filter =*/ llama_opt_param_filter_all, - /*param_filter_ud =*/ nullptr, - /*get_opt_pars =*/ common_opt_lr_pars, - /*get_opt_pars_ud =*/ ¶ms.lr, - /*optimizer_type =*/ params.optimizer, - /*checkpoint_path =*/ nullptr, + /*n_ctx_train =*/ 0, + /*param_filter =*/ llama_opt_param_filter_all, + /*param_filter_ud =*/ nullptr, + /*get_opt_pars =*/ common_opt_lr_pars, + /*get_opt_pars_ud =*/ ¶ms.lr, + /*optimizer_type =*/ params.optimizer, + /*checkpoint_path =*/ nullptr, /*load_optimizer_state =*/ false, + /*assistant_loss_only =*/ false, }; llama_opt_init(ctx.get(), model.get(), lopt_params); diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index 05f4482e414..89de0e17220 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -31,6 +31,7 @@ extern "C" { GGML_OPT_LOSS_TYPE_MEAN, GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, }; @@ -43,12 +44,23 @@ extern "C" { int64_t ne_label, // number of elements per label int64_t ndata, // total number of datapoints/labels int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init_with_masks( + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + enum ggml_type type_mask, // the type for the internal masks tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ne_mask, // number of elements per mask + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); // get underlying tensors that store the data GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + GGML_API struct ggml_tensor * ggml_opt_dataset_masks (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata], can be null // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); @@ -65,6 +77,13 @@ extern "C" { size_t nb_data_batch, void * labels_batch, int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host_with_masks( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + void * masks_batch, + int64_t ibatch); // ====== Model / Context ====== @@ -148,6 +167,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_inputs( 
ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against + GGML_API struct ggml_tensor * ggml_opt_masks( ggml_opt_context_t opt_ctx); // assistant masks for instruction tuning, can be null GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2e8d1deed2a..1f4427d7afc 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_MEAN, GGML_OP_ARGMAX, GGML_OP_COUNT_EQUAL, + GGML_OP_COUNT_EQUAL_MASKED, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, GGML_OP_CONCAT, @@ -551,6 +552,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_CROSS_ENTROPY_LOSS_MASKED, + GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK, GGML_OP_OPT_STEP_ADAMW, GGML_OP_OPT_STEP_SGD, @@ -992,6 +995,13 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // count number of equal elements in a and b, but only where mask=1 + GGML_API struct ggml_tensor * ggml_count_equal_masked( + struct ggml_context * ctx, + struct ggml_tensor * a, // predictions + struct ggml_tensor * b, // targets + struct ggml_tensor * c); // mask (1 for positions to count, 0 to skip) + // if a is the same shape as b, and a is not parameter, return a // otherwise, return a new tensor: repeat(a) to fit in b GGML_API struct ggml_tensor * ggml_repeat( @@ -2389,6 +2399,19 @@ extern "C" { struct ggml_tensor * b, // labels struct ggml_tensor * c); // gradients of cross_entropy_loss result + // Masked cross-entropy loss for instruction fine-tuning (assistant-only loss) + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_masked( + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b, // labels + struct ggml_tensor * c); // mask (1 for assistant tokens, 0 for masked) + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_masked_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b, // labels + struct ggml_tensor * c, // mask + struct ggml_tensor * d); // gradients of cross_entropy_loss result + // AdamW optimizer step // Paper: https://arxiv.org/pdf/1711.05101v3.pdf // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 4cad419e70c..d664dc5c668 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1743,6 +1743,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_count_equal(params, tensor); } break; + case GGML_OP_COUNT_EQUAL_MASKED: + { + ggml_compute_forward_count_equal_masked(params, tensor); + } break; case GGML_OP_REPEAT: { ggml_compute_forward_repeat(params, tensor); @@ -2032,6 +2036,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_cross_entropy_loss_back(params, tensor); } break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED: + { + ggml_compute_forward_cross_entropy_loss_masked(params, tensor); + } + 
break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + { + ggml_compute_forward_cross_entropy_loss_masked_back(params, tensor); + } + break; case GGML_OP_OPT_STEP_ADAMW: { ggml_compute_forward_opt_step_adamw(params, tensor); @@ -2161,6 +2175,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = 1; } break; case GGML_OP_COUNT_EQUAL: + case GGML_OP_COUNT_EQUAL_MASKED: { n_tasks = n_threads; } break; @@ -2347,6 +2362,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED: + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: { @@ -2726,6 +2743,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_COUNT_EQUAL: + case GGML_OP_COUNT_EQUAL_MASKED: { cur = ggml_type_size(node->type)*n_tasks; } break; @@ -2840,6 +2858,11 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED: + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + { + cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks) + sizeof(int64_t)*n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 49f2c898312..95e5f70d82f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2409,6 +2409,100 @@ void ggml_compute_forward_count_equal( } } +// ggml_compute_forward_count_equal_masked + +static void ggml_compute_forward_count_equal_masked_i32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(src2->ne[1] == src0->ne[0] || ggml_are_same_shape(src0, src2)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_I64); + + const int64_t nr = ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; + + const int64_t dr = (nr + nth - 1)/nth; + + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; + + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; + const char * data2 = (const char *) src2->data + i03*src2->nb[3] + i02*src2->nb[2] + i01*src2->nb[1]; + + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); + + float mask_val; + if (ggml_are_same_shape(src0, src2)) { + mask_val = *((const float *) (data2 + i00*src2->nb[0])); + } else { + const char * mask_ptr = (const char *) src2->data + 0*src2->nb[0] + i00*src2->nb[1] + 0*src2->nb[2]; + mask_val = *((const float *) mask_ptr); + } + + const bool mask = mask_val > 0.5f; + + if (mask == 1) { + sum_thread += val0 == val1; + } + } + } + if (ith != 
0) { + sums[ith] = sum_thread; + } + ggml_barrier(params->threadpool); + + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; + } + *((int64_t *) dst->data) = sum_thread; +} + +void ggml_compute_forward_count_equal_masked( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_I32: + { + ggml_compute_forward_count_equal_masked_i32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_repeat static void ggml_compute_forward_repeat_f32( @@ -10984,6 +11078,196 @@ void ggml_compute_forward_cross_entropy_loss_back( } } +// ggml_compute_forward_cross_entropy_loss_masked_f32 + +static void ggml_compute_forward_cross_entropy_loss_masked_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // logits + const ggml_tensor * src1 = dst->src[1]; // targets + const ggml_tensor * src2 = dst->src[2]; // mask (1 for assistant tokens, 0 for masked) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(src2->nb[0] == ggml_type_size(src2->type)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + float * st = ((float *) params->wdata) + nth + ith*nc; + float sum_thread = 0.0f; + int64_t valid_tokens_thread = 0; + + GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc) + sizeof(int64_t) * nth); + int64_t * valid_counts = (int64_t *)(((float *) params->wdata) + nth + nth * nc); + + const int64_t dr = (nr + nth - 1)/nth; + + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t i1 = ir0; i1 < ir1; ++i1) { + const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]); + const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]); + const float * mask_row = (const float *)((const char *) src2->data + i1*src2->nb[1]); + const float mask_value = mask_row[0]; + + if (mask_value <= 0.5f) continue; + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum_softmax >= 0.0); + + ggml_vec_add1_f32(nc, st, st, -sum_softmax); + + float sum_st = 0.0f; + for (int64_t i = 0; i < nc; i++) { + sum_st += st[i] * s1[i]; + } + + sum_thread += sum_st; + valid_tokens_thread++; + } + + sums[ith] = sum_thread; + valid_counts[ith] = valid_tokens_thread; + ggml_barrier(params->threadpool); + + if (ith == 0) { + float total_loss = 0.0f; + int64_t total_valid = 0; + + for (int i = 0; i < nth; i++) { + total_loss += sums[i]; + total_valid += valid_counts[i]; + } + + float * dp = (float *) dst->data; + if (total_valid > 0) { + float final_loss = -total_loss / (float)total_valid; + dp[0] = final_loss; + } else { + dp[0] = 0.0f; + } + } +} + +// ggml_compute_forward_cross_entropy_loss_masked_back_f32 + +static void ggml_compute_forward_cross_entropy_loss_masked_back_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + const 
ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; + const ggml_tensor * src2f = dst->src[3]; + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(src2f)); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + const int64_t nc = src0f->ne[0]; + const int64_t nr = ggml_nrows(src0f); + + int64_t total_valid = 0; + for (int64_t i1 = 0; i1 < nr; i1++) { + const float * mask_row = (const float *)((const char *) src2f->data + i1*src2f->nb[1]); + const float mask_value = mask_row[0]; + if (mask_value > 0.5f) { + total_valid++; + } + } + + const int64_t dr = (nr + nth - 1)/nth; + + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + const float upstream_grad = ((const float *) grad->data)[0]; + + float d_scale = 0.0f; + if (total_valid > 0) { + d_scale = upstream_grad / (float) total_valid; + } + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]); + const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]); + const float * mask_row = (const float *)((const char *) src2f->data + i1*src2f->nb[1]); + const float mask_value = mask_row[0]; + + if (mask_value > 0.5f) { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); + assert(sum > 0.0); + ggml_vec_scale_f32(nc, ds0, 1.0/sum); + + ggml_vec_sub_f32(nc, ds0, ds0, s1); + ggml_vec_scale_f32(nc, ds0, d_scale); + } else { + ggml_vec_set_f32(nc, ds0, 0.0f); + + } + } +} + +void ggml_compute_forward_cross_entropy_loss_masked( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_masked_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +void ggml_compute_forward_cross_entropy_loss_masked_back( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_masked_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + static void ggml_compute_forward_opt_step_adamw_f32( const ggml_compute_params * params, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0aeb5e24b0d..00b31283111 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -37,6 +37,7 @@ void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, st void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_count_equal_masked(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_repeat(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_repeat_back(const 
struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -108,6 +109,8 @@ void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cross_entropy_loss_masked(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_cross_entropy_loss_masked_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 4aad7cb154e..f77933b29eb 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -18,12 +18,13 @@ struct ggml_opt_dataset { ggml_backend_buffer_t buf = nullptr; struct ggml_tensor * data = nullptr; struct ggml_tensor * labels = nullptr; + struct ggml_tensor * masks = nullptr; int64_t ndata = -1; int64_t ndata_shard = -1; size_t nbs_data = -1; size_t nbs_labels = -1; - + size_t nbs_masks = -1; std::vector permutation; }; @@ -45,10 +46,12 @@ struct ggml_opt_context { struct ggml_tensor * inputs = nullptr; struct ggml_tensor * outputs = nullptr; struct ggml_tensor * labels = nullptr; + struct ggml_tensor * masks = nullptr; struct ggml_tensor * loss = nullptr; struct ggml_tensor * pred = nullptr; struct ggml_tensor * ncorrect = nullptr; + struct ggml_tensor * nmasked = nullptr; struct ggml_cgraph * gf = nullptr; struct ggml_cgraph * gb_grad = nullptr; @@ -76,6 +79,7 @@ struct ggml_opt_result { std::vector loss; std::vector pred; int64_t ncorrect = 0; + int64_t nmasked = 0; int64_t opt_period = -1; bool loss_per_datapoint = false; @@ -129,6 +133,63 @@ ggml_opt_dataset_t ggml_opt_dataset_init( return result; } +ggml_opt_dataset_t ggml_opt_dataset_init_with_masks( + enum ggml_type type_data, + enum ggml_type type_label, + enum ggml_type type_mask, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ne_mask, + int64_t ndata, + int64_t ndata_shard) { + GGML_ASSERT(ne_datapoint > 0); + GGML_ASSERT(ne_label >= 0); + GGML_ASSERT(ne_mask >= 0); + GGML_ASSERT(ndata > 0); + GGML_ASSERT(ndata_shard > 0); + + ggml_opt_dataset_t result = new ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct ggml_init_params params = { + /*.mem_size =*/ 3*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = ggml_init(params); + } + + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); + result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); + result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + if (ne_mask > 0) { + result->masks = ggml_new_tensor_2d(result->ctx, type_mask, 
ne_mask, ndata); + result->nbs_masks = ggml_nbytes(result->masks) * ndata_shard/ndata; + } else { + result->masks = nullptr; + result->nbs_masks = 0; + } + + result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { ggml_backend_buffer_free(dataset->buf); ggml_free(dataset->ctx); @@ -147,6 +208,10 @@ struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) { return dataset->labels; } +struct ggml_tensor * ggml_opt_dataset_masks(ggml_opt_dataset_t dataset) { + return dataset->masks; +} + void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) { GGML_ASSERT(idata <= dataset->ndata); @@ -218,6 +283,36 @@ void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_bat } } +void ggml_opt_dataset_get_batch_host_with_masks(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, void * masks_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT((masks_batch == nullptr) == (dataset->masks == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (labels_batch) { + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } + + if (masks_batch) { + const char * ptr_masks = (const char *) dataset->masks->data + ishard *dataset->nbs_masks; + char * ptr_masks_batch = (char *) masks_batch + ishard_batch*dataset->nbs_masks; + memcpy(ptr_masks_batch, ptr_masks, dataset->nbs_masks); + } + } +} + // ====== Model / Context ====== struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { @@ -412,6 +507,22 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { opt_ctx->loss_per_datapoint = true; break; } + case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED: { + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->masks = ggml_new_tensor(ctx_results, GGML_TYPE_F32, GGML_MAX_DIMS, opt_ctx->outputs->ne); + ggml_set_input(opt_ctx->masks); + ggml_set_name(opt_ctx->masks, "masks"); + opt_ctx->loss = ggml_cross_entropy_loss_masked(ctx_results, opt_ctx->outputs, opt_ctx->labels, opt_ctx->masks); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_masked"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_masked_scaled"); + } + opt_ctx->loss_per_datapoint = true; + break; + } 
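+        // NOTE (documentation comment): the masks tensor above is allocated with the
+        // same shape as the outputs/labels (logits shape), but the masked kernels in
+        // this patch (ggml_compute_forward_cross_entropy_loss_masked_f32 and the
+        // masked count-equal path) read only the first element of each row, i.e. one
+        // 0/1 float flag per token position; the remaining elements of each row are
+        // ignored. Duplicating the logits shape keeps the graph wiring identical to
+        // the unmasked cross-entropy case.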
case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); ggml_set_input(opt_ctx->labels); @@ -433,16 +544,25 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { ggml_set_loss(opt_ctx->loss); ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); - if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY || opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED) { opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); ggml_set_name(opt_ctx->pred, "pred"); ggml_set_output(opt_ctx->pred); ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); - opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); - ggml_set_name(opt_ctx->ncorrect, "ncorrect"); - ggml_set_output(opt_ctx->ncorrect); - ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && opt_ctx->masks != nullptr) { + // For instruction fine-tuning with masks, use masked accuracy calculation + struct ggml_tensor * labels_argmax = ggml_argmax(ctx_results, opt_ctx->labels); + opt_ctx->ncorrect = ggml_count_equal_masked(ctx_results, opt_ctx->pred, labels_argmax, opt_ctx->masks); + ggml_set_name(opt_ctx->ncorrect, "ncorrect_masked"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); + } else { + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); + } } if (opt_ctx->buf_static) { @@ -617,6 +737,10 @@ struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) { return opt_ctx->labels; } +struct ggml_tensor * ggml_opt_masks(ggml_opt_context_t opt_ctx) { + return opt_ctx->masks; +} + struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) { return opt_ctx->loss; } @@ -677,6 +801,7 @@ void ggml_opt_result_reset(ggml_opt_result_t result) { result->loss.clear(); result->pred.clear(); result->ncorrect = 0; + result->nmasked = 0; } void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) { @@ -725,14 +850,15 @@ void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) { } void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) { - *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + int64_t denominator = (result->nmasked > 0) ? result->nmasked : result->ndata; + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(denominator) : NAN; if (!unc) { return; } - *unc = result->ncorrect >= 0 && result->ndata >= 2 ? - sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; + *unc = result->ncorrect >= 0 && denominator >= 2 ? 
+ sqrt((*accuracy) * (1.0 - (*accuracy)) / double(denominator - 1)) : NAN; } // ====== Computation ====== @@ -902,6 +1028,24 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { int64_t ncorrect; ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect)); result->ncorrect += ncorrect; + + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && opt_ctx->masks) { + int64_t total_valid = 0; + const int64_t nr = opt_ctx->masks->ne[1]; + + const size_t mask_size = ggml_nbytes(opt_ctx->masks); + std::vector mask_data(mask_size / sizeof(float)); + ggml_backend_tensor_get(opt_ctx->masks, mask_data.data(), 0, mask_size); + + for (int64_t i1 = 0; i1 < nr; i1++) { + const size_t idx = i1 * (opt_ctx->masks->ne[0]); + const float mask_value = mask_data[idx]; + if (mask_value > 0.5f) { + total_valid++; + } + } + result->nmasked += total_valid; + } } // ====== High-Level Functions ====== diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a61f912f933..0431f2328a6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -544,6 +544,8 @@ struct vk_device_struct { vk_pipeline pipeline_geglu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_cross_entropy_loss_back_f32; + vk_pipeline pipeline_cross_entropy_loss_masked_back_f32; + vk_pipeline pipeline_count_equal_masked_i32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; @@ -3403,6 +3405,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_masked_back_f32, "cross_entropy_loss_masked_back_f32", cross_entropy_loss_masked_back_f32_len, cross_entropy_loss_masked_back_f32_data, "main", 5, sizeof(vk_op_push_constants), {1, 1, 1}, { 32 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_masked_i32, "count_equal_masked_i32", count_equal_masked_i32_len, count_equal_masked_i32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -7747,6 +7751,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cross_entropy_loss_back_f32; } return nullptr; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cross_entropy_loss_masked_back_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL_MASKED: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && src2->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I64) { + return 
ctx->device->pipeline_count_equal_masked_i32; + } + return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); @@ -8242,6 +8256,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + { + const uint32_t nr = ggml_nrows(src1); + elements = { nr, 1, 1 }; + } break; + case GGML_OP_COUNT_EQUAL_MASKED: + { + const uint32_t n_elements = ggml_nelements(src0); + const uint32_t chunk_size = 512; + elements = { CEIL_DIV(n_elements, chunk_size), 1, 1 }; + } break; case GGML_OP_RMS_NORM: if (ctx->do_add_rms_partials) { // Run one element per thread, 128 threads per workgroup @@ -9202,6 +9227,124 @@ static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_co } +static void ggml_vk_op_f32_cross_entropy_loss_masked_back(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * grad = dst->src[0]; // gradient of forward pass output + const ggml_tensor * logits = dst->src[1]; // logits + const ggml_tensor * labels = dst->src[2]; // targets + const ggml_tensor * mask = dst->src[3]; // mask + + GGML_ASSERT(grad->type == GGML_TYPE_F32); + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(labels->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_is_contiguous(logits)); + GGML_ASSERT(ggml_is_contiguous(labels)); + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(ggml_are_same_shape(logits, labels)); + GGML_ASSERT(ggml_are_same_shape(logits, dst)); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, logits, labels, mask, dst, GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + // Get buffer contexts + ggml_backend_vk_buffer_context * grad_buf_ctx = (ggml_backend_vk_buffer_context *)grad->buffer->context; + ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context; + ggml_backend_vk_buffer_context * labels_buf_ctx = (ggml_backend_vk_buffer_context *)labels->buffer->context; + ggml_backend_vk_buffer_context * mask_buf_ctx = (ggml_backend_vk_buffer_context *)mask->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + ggml_vk_sync_buffers(ctx, subctx); + + vk_buffer d_grad = nullptr, d_logits = nullptr, d_labels = nullptr, d_mask = nullptr, d_dst = nullptr; + size_t grad_offset = 0, logits_offset = 0, labels_offset = 0, mask_offset = 0, dst_offset = 0; + bool grad_uma = false, logits_uma = false, labels_uma = false, mask_uma = false, dst_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, grad->data, d_grad, grad_offset); + ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_offset); + ggml_vk_host_get(ctx->device, labels->data, d_labels, labels_offset); + ggml_vk_host_get(ctx->device, mask->data, d_mask, mask_offset); + ggml_vk_host_get(ctx->device, dst->data, d_dst, dst_offset); + + grad_uma = d_grad != nullptr; + logits_uma = d_logits != nullptr; + labels_uma = d_labels != nullptr; + mask_uma = d_mask != nullptr; + dst_uma = d_dst != nullptr; + } + + if (!grad_uma) { + d_grad = 
grad_buf_ctx->dev_buffer; + grad_offset = vk_tensor_offset(grad) + grad->view_offs; + } + if (!logits_uma) { + d_logits = logits_buf_ctx->dev_buffer; + logits_offset = vk_tensor_offset(logits) + logits->view_offs; + } + if (!labels_uma) { + d_labels = labels_buf_ctx->dev_buffer; + labels_offset = vk_tensor_offset(labels) + labels->view_offs; + } + if (!mask_uma) { + d_mask = mask_buf_ctx->dev_buffer; + mask_offset = vk_tensor_offset(mask) + mask->view_offs; + } + if (!dst_uma) { + d_dst = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + const uint64_t grad_size = ggml_nbytes(grad); + const uint64_t logits_size = ggml_nbytes(logits); + const uint64_t labels_size = ggml_nbytes(labels); + const uint64_t mask_size = ggml_nbytes(mask); + const uint64_t dst_size = ggml_nbytes(dst); + + std::array elements = { (uint32_t)ggml_nrows(logits), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_grad, grad_offset, grad_size }, + vk_subbuffer{ d_logits, logits_offset, logits_size }, + vk_subbuffer{ d_labels, labels_offset, labels_size }, + vk_subbuffer{ d_mask, mask_offset, mask_size }, + vk_subbuffer{ d_dst, dst_offset, dst_size }, + }, pc, elements); +} + +static void ggml_vk_cross_entropy_loss_masked_back(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const ggml_tensor * logits = dst->src[1]; + + const int64_t nclasses = logits->ne[0]; + const int64_t nrows = ggml_nrows(logits); + + float upstream_grad = 1.0f; + ggml_vk_op_f32_cross_entropy_loss_masked_back(ctx, subctx, dst, { + (uint32_t)nclasses, + (uint32_t)nrows, + upstream_grad, + 0.0f + }, dryrun); +} + +static void ggml_vk_count_equal_masked(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * predictions, const ggml_tensor * targets, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { + const int64_t n_elements = ggml_nelements(predictions); + const int64_t vocab_size = mask->ne[0]; + + ggml_vk_op_f32(ctx, subctx, predictions, targets, mask, dst, GGML_OP_COUNT_EQUAL_MASKED, { + (uint32_t)n_elements, + (uint32_t)vocab_size, + 0.0f, + 0.0f + }, dryrun); +} + static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -10616,6 +10759,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + case GGML_OP_COUNT_EQUAL_MASKED: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: @@ -10967,7 +11112,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_vk_cross_entropy_loss_back(ctx, compute_ctx, src0, src1, src2, node, dryrun); - + break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + ggml_vk_cross_entropy_loss_masked_back(ctx, compute_ctx, node, dryrun); + break; + case GGML_OP_COUNT_EQUAL_MASKED: + ggml_vk_count_equal_masked(ctx, compute_ctx, src0, src1, src2, node, dryrun); break; case GGML_OP_SOFT_MAX: ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); @@ -11153,6 +11303,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: 
+ case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + case GGML_OP_COUNT_EQUAL_MASKED: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: @@ -12630,6 +12782,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_is_contiguous(op->src[1]) && ggml_is_contiguous(op)); } + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_COUNT_EQUAL_MASKED: + return op->src[0]->type == GGML_TYPE_I32 && op->src[1]->type == GGML_TYPE_I32 && op->src[2]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_I64; default: return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal_masked.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal_masked.comp new file mode 100644 index 00000000000..7b5c25a23e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal_masked.comp @@ -0,0 +1,46 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_atomic_int64 : enable + +#include "types.comp" +#include "generic_head.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; // predictions +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; // targets +layout (binding = 2) readonly buffer M {C_TYPE data_m[];}; // masks (float, 1.0 for positions to count, 0.0 to skip) +layout (binding = 3) buffer D {D_TYPE data_d[];}; // output count + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + if (gl_WorkGroupID.x == 0 && gl_LocalInvocationID.x == 0) { + data_d[0] = D_TYPE(0); + } + + barrier(); + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + + uint position = idx; + uint mask_offset = 0 + position * p.KY + 0; + float mask_value = data_m[mask_offset]; + + if (mask_value > 0.5 && data_a[idx] == data_b[idx]) { + count += 1; + } + } + + atomicAdd(data_d[0], D_TYPE(count)); +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_masked_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_masked_back.comp new file mode 100644 index 00000000000..803855b0fdb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_masked_back.comp @@ -0,0 +1,115 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +#define FLOAT_TYPE float + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer C {C_TYPE data_c[];}; +layout (binding = 3) readonly buffer D {E_TYPE data_d[];}; +layout (binding = 4) writeonly buffer E {D_TYPE data_e[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +void main() { + const uint nclasses = p.KX; + const uint nrows = p.KY; + + const uint row = gl_WorkGroupID.x; + if (row >= nrows) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint warp_size = gl_WorkGroupSize.x; + + const uint logits_offset = row * nclasses; + const uint labels_offset = row * 
nclasses; + const uint dst_offset = row * nclasses; + + const float mask_value = data_d[row * nclasses + 0]; + + if (mask_value <= 0.5) { + for (uint i = tid; i < nclasses; i += warp_size) { + data_e[dst_offset + i] = D_TYPE(0.0); + } + return; + } + + FLOAT_TYPE d_by_valid; + if (tid == 0) { + FLOAT_TYPE valid_tokens = FLOAT_TYPE(0.0); + for (uint r = 0; r < nrows; r++) { + const float mask_val = data_d[r * nclasses + 0]; + if (mask_val > 0.5) { + valid_tokens += FLOAT_TYPE(1.0); + } + } + + const FLOAT_TYPE upstream_grad = FLOAT_TYPE(p.param1); + d_by_valid = valid_tokens > 0.0 ? upstream_grad / valid_tokens : FLOAT_TYPE(0.0); + + vals[0] = d_by_valid; + } + barrier(); + + d_by_valid = vals[0]; + + FLOAT_TYPE thread_max = FLOAT_TYPE(uintBitsToFloat(0xFF800000)); // -INFINITY + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_max = max(thread_max, val); + } + + vals[tid] = thread_max; + barrier(); + + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + const FLOAT_TYPE row_max = vals[0]; + barrier(); + + FLOAT_TYPE thread_sum = FLOAT_TYPE(0.0); + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_sum += exp(val - row_max); + } + + vals[tid] = thread_sum; + barrier(); + + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE row_sum = vals[0]; + const FLOAT_TYPE sm_scale = FLOAT_TYPE(1.0) / row_sum; + barrier(); + + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE logit = FLOAT_TYPE(data_b[logits_offset + i]); + FLOAT_TYPE softmax_val = exp(logit - row_max) * sm_scale; + + FLOAT_TYPE label = FLOAT_TYPE(data_c[labels_offset + i]); + FLOAT_TYPE gradient = (softmax_val - label) * d_by_valid; + + data_e[dst_offset + i] = D_TYPE(gradient); + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 61ebc6c61c6..43251cdb9ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -693,6 +693,8 @@ void process_shaders() { string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cross_entropy_loss_back_f32", "cross_entropy_loss_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cross_entropy_loss_masked_back_f32", "cross_entropy_loss_masked_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}}); + string_to_spv("count_equal_masked_i32", "count_equal_masked.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"C_TYPE", "float"}, {"D_TYPE", "int64_t"}}); string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 40524758bfd..3a279e95133 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -938,6 +938,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "MEAN", "ARGMAX", "COUNT_EQUAL", + "COUNT_EQUAL_MASKED", "REPEAT", 
"REPEAT_BACK", "CONCAT", @@ -1014,13 +1015,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + "CROSS_ENTROPY_LOSS_MASKED", + "CROSS_ENTROPY_LOSS_MASKED_BACK", "OPT_STEP_ADAMW", "OPT_STEP_SGD", "GLU", }; -static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); +static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1125,7 +1128,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); +static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2386,6 +2389,26 @@ struct ggml_tensor * ggml_count_equal( return result; } +// ggml_count_equal_masked + +struct ggml_tensor * ggml_count_equal_masked( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(c->type == GGML_TYPE_F32); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1); + + result->op = GGML_OP_COUNT_EQUAL_MASKED; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + // ggml_repeat struct ggml_tensor * ggml_repeat( @@ -5737,6 +5760,9 @@ struct ggml_tensor * ggml_cross_entropy_loss( result->op = GGML_OP_CROSS_ENTROPY_LOSS; result->src[0] = a; result->src[1] = b; + + // Initialize op_params to 0 (no masking) + *(int32_t *)(result->op_params) = 0; return result; } @@ -5761,6 +5787,51 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( return result; } +// ggml_cross_entropy_loss_masked + +struct ggml_tensor * ggml_cross_entropy_loss_masked( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, c)); + GGML_ASSERT(c->type == GGML_TYPE_F32); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_MASKED; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +// ggml_cross_entropy_loss_masked_back + +struct ggml_tensor * ggml_cross_entropy_loss_masked_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + struct ggml_tensor * d) { + GGML_ASSERT(ggml_is_scalar(d)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, c)); + GGML_ASSERT(c->type == GGML_TYPE_F32); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_MASKED_BACK; + result->src[0] = d; + result->src[1] = a; + result->src[2] = b; + result->src[3] = c; + + return result; +} + // opt_step_adamw struct ggml_tensor * ggml_opt_step_adamw( @@ -6396,6 +6467,13 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_CROSS_ENTROPY_LOSS_MASKED: { + if (src0_needs_grads) { + struct ggml_tensor * mask_tensor = tensor->src[2]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_masked_back(ctx, src0, src1, mask_tensor, grad)); + } + GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; case GGML_OP_GLU: { switch (ggml_get_glu_op(tensor)) { case GGML_GLU_OP_SWIGLU: { diff --git a/include/llama.h b/include/llama.h index 82144182bc0..b3f56c30db8 100644 --- 
a/include/llama.h +++ b/include/llama.h @@ -1388,6 +1388,8 @@ extern "C" { // Optional checkpoint loading const char * checkpoint_path; // path to checkpoint file to load optimizer state from (nullptr = don't load) bool load_optimizer_state; // whether to load optimizer state from checkpoint_path + + bool assistant_loss_only; }; LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a661cc2e1ab..24637c3d16e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -696,7 +696,7 @@ void llama_context::set_adapter_lora( llama_adapter_lora * adapter, float scale) { LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); - + loras[adapter] = scale; } @@ -2056,7 +2056,10 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); GGML_ASSERT(n_batch % n_ubatch == 0); - ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + opt_loss_type = lopt_params.assistant_loss_only ? + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED : GGML_OPT_LOSS_TYPE_CROSS_ENTROPY; + + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), opt_loss_type); opt_params.opt_period = n_batch / n_ubatch; opt_params.get_opt_pars = lopt_params.get_opt_pars; opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; @@ -2120,6 +2123,7 @@ void llama_context::opt_epoch_iter( ggml_opt_result_t result, const std::vector & tokens, const std::vector & labels_sparse, + const std::vector & masks_sparse, llama_batch & batch, ggml_opt_epoch_callback callback, bool train, @@ -2130,7 +2134,7 @@ void llama_context::opt_epoch_iter( const uint32_t n_ctx = llama_model_n_ctx_train(&model); const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); - + memory->clear(true); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { @@ -2142,7 +2146,7 @@ void llama_context::opt_epoch_iter( batch.seq_id [pos_batch][0] = 0; batch.logits [pos_batch] = true; } - + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; @@ -2215,14 +2219,42 @@ void llama_context::opt_epoch_iter( res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + struct ggml_tensor * masks = (opt_loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && ggml_opt_dataset_masks(dataset)) ? 
ggml_opt_masks(opt_ctx) : nullptr; // Only get masks if using masked loss GGML_ASSERT(labels->ne[1] == n_ubatch); + GGML_ASSERT(labels->type == GGML_TYPE_F32); + if (masks) { + GGML_ASSERT(masks->ne[1] == n_ubatch); + GGML_ASSERT(masks->type == GGML_TYPE_F32); + } ggml_set_zero(labels); const float onef = 1.0f; for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; + const uint32_t imask = pos_ctx + pos_batch + pos_ubatch; + if (masks && imask < masks_sparse.size() && masks_sparse[imask] == 0) { + continue; + } + GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); } + + if (masks) { + if (pos_batch == 0) { + ggml_set_zero(masks); + } + const float onef = 1.0f; + for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { + const uint32_t imask = pos_ctx + pos_batch + pos_ubatch; + + if (imask < masks_sparse.size() && masks_sparse[imask] == 1) { + const size_t offset = (pos_ubatch * masks->ne[0] + 0) * sizeof(float); + ggml_backend_tensor_set(masks, &onef, offset, sizeof(float)); + } + } + + ggml_backend_sched_synchronize(get_sched()); + } } ggml_opt_eval(opt_ctx, result); if (callback) { @@ -2256,6 +2288,7 @@ void llama_context::opt_epoch( struct llama_batch batch = llama_batch_init(n_batch, 0, 1); std::vector tokens(n_ctx); std::vector labels_sparse(n_ctx); + std::vector masks_sparse(n_ctx); int64_t idata = (resume_from_batch >= 0) ? resume_from_batch + 1 : 0; @@ -2265,8 +2298,13 @@ void llama_context::opt_epoch( constexpr bool train = true; const int64_t idata_in_loop = idata*ubatch_per_ctx; - ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); - opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, + if (opt_loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && ggml_opt_dataset_masks(dataset)) { + ggml_opt_dataset_get_batch_host_with_masks(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), masks_sparse.data(), idata); + + } else { + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + } + opt_epoch_iter(dataset, result_train, tokens, labels_sparse, masks_sparse, batch, callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); } @@ -2276,8 +2314,12 @@ void llama_context::opt_epoch( constexpr bool train = false; const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; - ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); - opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, + if (opt_loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && ggml_opt_dataset_masks(dataset)) { + ggml_opt_dataset_get_batch_host_with_masks(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), masks_sparse.data(), idata); + } else { + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + } + opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, masks_sparse, batch, callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); } diff --git a/src/llama-context.h b/src/llama-context.h index bbb8ab32272..3352fea4a33 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -172,6 +172,7 @@ struct llama_context { ggml_opt_result_t result, const std::vector & tokens, const 
std::vector & labels_sparse,
+            const std::vector & masks_sparse,
             llama_batch & batch,
             ggml_opt_epoch_callback callback,
             bool train,
@@ -275,6 +276,7 @@ struct llama_context {
     std::string pending_optimizer_checkpoint_path;
     bool should_load_optimizer_tensors = false;
     bool optimizer_tensors_loaded = false;
+    ggml_opt_loss_type opt_loss_type = GGML_OPT_LOSS_TYPE_CROSS_ENTROPY;

     ggml_threadpool_t threadpool = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;

From f1bed73b7e4a8b07189b7deb4e078b97b874ad9a Mon Sep 17 00:00:00 2001
From: Marcus Edel
Date: Mon, 27 Oct 2025 16:40:49 -0400
Subject: [PATCH 2/5] Add learning rate scheduler: constant (default), linear, and cosine.

Signed-off-by: Marcus Edel
---
 examples/training/finetune-lora.cpp | 316 +++++++++++++++++++++++-----
 1 file changed, 261 insertions(+), 55 deletions(-)

diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp
index dc920f36ab1..f23287bca20 100644
--- a/examples/training/finetune-lora.cpp
+++ b/examples/training/finetune-lora.cpp
@@ -4,7 +4,12 @@
 #include "llama.h"
 #include "ggml-backend.h"

+#include
+#include
+#include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -17,6 +22,115 @@
 struct checkpoint_callback_data;
 static checkpoint_callback_data* g_checkpoint_data = nullptr;

+enum class lora_lr_schedule_type : std::uint8_t {
+    CONSTANT,
+    COSINE,
+    LINEAR,
+};
+
+struct lora_lr_scheduler_state {
+    lora_lr_schedule_type schedule = lora_lr_schedule_type::CONSTANT;
+    float lr_init = 1e-5f;
+    float lr_min = 0.0f;
+    float weight_decay = 0.0f;
+    int64_t total_steps = 0;
+    int64_t current_step = 0;
+    float last_lr = 0.0f;
+};
+
+static bool lora_lr_scheduler_type_from_string(const std::string & name, lora_lr_schedule_type & out) {
+    auto equals = [](const std::string & lhs, const char * rhs) {
+        const size_t rhs_len = std::strlen(rhs);
+        if (lhs.size() != rhs_len) {
+            return false;
+        }
+        for (size_t i = 0; i < rhs_len; ++i) {
+            if (std::tolower(static_cast<unsigned char>(lhs[i])) !=
+                std::tolower(static_cast<unsigned char>(rhs[i]))) {
+                return false;
+            }
+        }
+        return true;
+    };
+
+    if (equals(name, "constant")) {
+        out = lora_lr_schedule_type::CONSTANT;
+        return true;
+    }
+    if (equals(name, "cosine")) {
+        out = lora_lr_schedule_type::COSINE;
+        return true;
+    }
+    if (equals(name, "linear")) {
+        out = lora_lr_schedule_type::LINEAR;
+        return true;
+    }
+    return false;
+}
+
+static const char * lora_lr_scheduler_type_to_cstr(lora_lr_schedule_type type) {
+    switch (type) {
+        case lora_lr_schedule_type::LINEAR: return "linear";
+        case lora_lr_schedule_type::COSINE: return "cosine";
+        case lora_lr_schedule_type::CONSTANT: return "constant";
+    }
+    return "constant";
+}
+
+static float lora_scheduler_lr_for_step(const lora_lr_scheduler_state & state, int64_t step) {
+
+    if (state.total_steps <= 0) {
+        return std::max(state.lr_init, 0.0f);
+    }
+
+    const int64_t clamped_step = std::min(std::max<int64_t>(step, 0), state.total_steps);
+    float lr = state.lr_init;
+
+    switch (state.schedule) {
+        case lora_lr_schedule_type::CONSTANT:
+            lr = state.lr_init;
+            break;
+        case lora_lr_schedule_type::COSINE: {
+            constexpr float kPi = 3.14159265358979323846f;
+            const float progress = static_cast<float>(clamped_step) / static_cast<float>(state.total_steps);
+            const float cosine = 0.5f * (1.0f + std::cos(progress * kPi));
+            lr = state.lr_min + (state.lr_init - state.lr_min) * cosine;
+            break;
+        }
+        case lora_lr_schedule_type::LINEAR: {
+            const float progress = static_cast<float>(clamped_step) / static_cast<float>(state.total_steps);
+            lr = state.lr_init + (state.lr_min -
state.lr_init) * progress; + break; + } + } + + return std::max(lr, 0.0f); +} + +static struct ggml_opt_optimizer_params lora_scheduler_get_optimizer_params(void * userdata) { + auto * scheduler = static_cast(userdata); + struct ggml_opt_optimizer_params params = ggml_opt_get_default_optimizer_params(nullptr); + + if (!scheduler) { + return params; + } + + const float lr = lora_scheduler_lr_for_step(*scheduler, scheduler->current_step); + scheduler->last_lr = lr; + + params.adamw.alpha = lr; + params.adamw.wd = scheduler->weight_decay; + + params.sgd.alpha = lr; + params.sgd.wd = scheduler->weight_decay; + + if (scheduler->current_step < scheduler->total_steps) { + scheduler->current_step++; + } + + return params; +} + static uint32_t parse_lora_modules(const std::string& modules_str) { if (modules_str.empty()) { return LLAMA_LORA_TARGET_ATTN_Q | LLAMA_LORA_TARGET_ATTN_K | LLAMA_LORA_TARGET_ATTN_V | LLAMA_LORA_TARGET_ATTN_O; @@ -135,6 +249,10 @@ static void print_lora_usage() { printf(" --assistant-loss-only Use JSON dataset format with masked loss (ChatML/conversation format)\n"); printf(" Only computes loss on assistant responses, not system/user prompts\n"); printf(" --chat-template PATH Optional Jinja chat template to render JSON dataset (matches HF apply_chat_template)\n"); + printf(" --learning-rate F AdamW learning rate (default: 1e-5)\n"); + printf(" --weight-decay F AdamW weight decay (default: 1e-2)\n"); + printf(" --lr-scheduler TYPE Learning rate scheduler: constant, cosine, linear (default: constant)\n"); + printf(" --lr-min F Minimum LR for cosine/linear schedulers (default: 0)\n"); printf("\nCheckpointing Options:\n"); printf(" --checkpoint-save-steps N Save checkpoint every N training steps (default: 100)\n"); printf(" --checkpoint-save-dir PATH Directory for checkpoints (default: ./checkpoints)\n"); @@ -167,10 +285,10 @@ static std::string find_latest_checkpoint(const std::string& checkpoint_dir) { if (!std::filesystem::exists(checkpoint_dir)) { return ""; } - + std::string latest_checkpoint; int64_t latest_step = -1; - + for (const auto& entry : std::filesystem::directory_iterator(checkpoint_dir)) { if (entry.is_directory()) { std::string dirname = entry.path().filename().string(); @@ -188,7 +306,7 @@ static std::string find_latest_checkpoint(const std::string& checkpoint_dir) { } } } - + return latest_checkpoint; } @@ -199,12 +317,12 @@ static bool save_checkpoint(llama_context* ctx, llama_adapter_lora* adapter, co return false; } } - + if (!llama_lora_save_checkpoint(adapter, checkpoint_dir.c_str(), llama_get_model(ctx), ctx)) { LOG_ERR("Failed to save LoRA checkpoint\n"); return false; } - + std::string meta_path = checkpoint_dir + "/metadata.json"; std::ofstream meta_file(meta_path); if (meta_file.is_open()) { @@ -217,21 +335,21 @@ static bool save_checkpoint(llama_context* ctx, llama_adapter_lora* adapter, co LOG_ERR("Failed to save checkpoint metadata\n"); return false; } - + LOG_INF("Checkpoint saved successfully to %s\n", checkpoint_dir.c_str()); return true; } static bool validate_checkpoint_metadata(const std::string& checkpoint_path, checkpoint_metadata& metadata) { std::string checkpoint_dir = checkpoint_path; - + if (!std::filesystem::exists(checkpoint_dir)) { LOG_ERR("Checkpoint directory does not exist: %s\n", checkpoint_dir.c_str()); return false; } - + LOG_INF("Loading checkpoint from: %s\n", checkpoint_dir.c_str()); - + std::string meta_path = checkpoint_dir + "/metadata.json"; if (std::filesystem::exists(meta_path)) { std::ifstream meta_file(meta_path); 
@@ -242,7 +360,7 @@ static bool validate_checkpoint_metadata(const std::string& checkpoint_path, che if (eq_pos != std::string::npos) { std::string key = line.substr(0, eq_pos); std::string value = line.substr(eq_pos + 1); - + if (key == "epoch") { metadata.epoch = std::stoi(value); } else if (key == "lora_rank") { @@ -263,7 +381,7 @@ static bool validate_checkpoint_metadata(const std::string& checkpoint_path, che LOG_ERR("Checkpoint metadata file not found: %s\n", meta_path.c_str()); return false; } - + LOG_INF("Checkpoint loaded successfully\n"); return true; } @@ -281,6 +399,7 @@ struct checkpoint_callback_data { float lora_alpha; uint32_t target_modules; float learning_rate; + lora_lr_scheduler_state * lr_scheduler; std::string model_path; std::string dataset_path; }; @@ -294,42 +413,48 @@ static void checkpoint_progress_callback( int64_t ibatch_max, int64_t t_start_us) { ggml_opt_epoch_callback_progress_bar(train, opt_ctx, dataset, result, ibatch, ibatch_max, t_start_us); - - if (!train) return; - + + if (!train) { + return; + } + checkpoint_callback_data* cb_data = g_checkpoint_data; - + if (!cb_data) { LOG_ERR("Checkpoint callback data is null!\n"); return; } - + + if (cb_data->lr_scheduler) { + cb_data->learning_rate = lora_scheduler_lr_for_step(*cb_data->lr_scheduler, cb_data->lr_scheduler->current_step); + } + if (cb_data->checkpoint_save_steps <= 0) { return; } - + cb_data->global_step++; - + if (cb_data->global_step % cb_data->checkpoint_save_steps == 0) { if (!cb_data->ctx) { LOG_ERR("Context is null in checkpoint callback!\n"); return; } - + if (!cb_data->adapter) { LOG_ERR("LoRA adapter is null in checkpoint callback!\n"); return; } - + checkpoint_metadata meta = { /*epoch =*/ cb_data->current_epoch, /*lora_rank =*/ cb_data->lora_rank, /*lora_alpha =*/ cb_data->lora_alpha, /*target_modules =*/ cb_data->target_modules, }; - + std::string checkpoint_path = get_checkpoint_filename(cb_data->checkpoint_save_dir, cb_data->global_step); - + if (!save_checkpoint(cb_data->ctx, cb_data->adapter, meta, checkpoint_path)) { LOG_ERR("Failed to save checkpoint at step %lld\n", (long long)cb_data->global_step); } @@ -341,9 +466,13 @@ struct finetune_params { float lora_alpha = 16.0f; std::string lora_modules_str; std::string output_adapter_path; - + int32_t num_epochs = 1; - + float learning_rate = 1e-5f; + float lr_min = 0.0f; + float weight_decay = 0.01f; + std::string lr_scheduler = "constant"; + int32_t checkpoint_save_steps = 100; std::string checkpoint_save_dir = "./checkpoints"; std::string resume_from_checkpoint; @@ -396,6 +525,22 @@ static bool parse_finetune_args(int& argc, char** argv, finetune_params& ft_para ft_params.num_epochs = std::atoi(argv[i + 1]); remove_arg_pair(i); i--; + } else if (strcmp(argv[i], "--learning-rate") == 0 && i + 1 < argc) { + ft_params.learning_rate = std::atof(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--weight-decay") == 0 && i + 1 < argc) { + ft_params.weight_decay = std::atof(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lr-scheduler") == 0 && i + 1 < argc) { + ft_params.lr_scheduler = argv[i + 1]; + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lr-min") == 0 && i + 1 < argc) { + ft_params.lr_min = std::atof(argv[i + 1]); + remove_arg_pair(i); + i--; } else if (strcmp(argv[i], "--checkpoint-save-steps") == 0 && i + 1 < argc) { ft_params.checkpoint_save_steps = std::atoi(argv[i + 1]); remove_arg_pair(i); @@ -427,7 +572,7 @@ static bool parse_finetune_args(int& argc, char** argv, 
finetune_params& ft_para print_lora_usage(); } } - + return true; } @@ -442,9 +587,39 @@ int main(int argc, char ** argv) { return 1; } + lora_lr_schedule_type scheduler_type; + if (!lora_lr_scheduler_type_from_string(ft_params.lr_scheduler, scheduler_type)) { + LOG_ERR("Unknown learning rate scheduler: %s (expected: constant, cosine, linear)\n", ft_params.lr_scheduler.c_str()); + return 1; + } + + if (ft_params.num_epochs <= 0) { + LOG_ERR("Number of epochs must be > 0, got %d\n", ft_params.num_epochs); + return 1; + } + if (ft_params.learning_rate <= 0.0f) { + LOG_ERR("Learning rate must be > 0, got %.4e\n", ft_params.learning_rate); + return 1; + } + if (ft_params.weight_decay < 0.0f) { + LOG_ERR("Weight decay must be >= 0, got %.4e\n", ft_params.weight_decay); + return 1; + } + if (ft_params.lr_min < 0.0f) { + LOG_ERR("Minimum learning rate must be >= 0, got %.4e\n", ft_params.lr_min); + return 1; + } + const bool scheduler_uses_lr_min = scheduler_type == lora_lr_schedule_type::COSINE || + scheduler_type == lora_lr_schedule_type::LINEAR; + if (scheduler_uses_lr_min && ft_params.lr_min > ft_params.learning_rate) { + LOG_ERR("For %s scheduler lr-min (%.4e) cannot exceed learning-rate (%.4e)\n", + lora_lr_scheduler_type_to_cstr(scheduler_type), ft_params.lr_min, ft_params.learning_rate); + return 1; + } + LOG_INF("Using LoRA parameters: rank=%d, alpha=%.1f\n", ft_params.lora_rank, ft_params.lora_alpha); LOG_INF("Training for %d epochs\n", ft_params.num_epochs); - + // Handle checkpoint auto-resume before model initialization if (ft_params.auto_resume && ft_params.resume_from_checkpoint.empty()) { std::string latest_checkpoint = find_latest_checkpoint(ft_params.checkpoint_save_dir); @@ -453,16 +628,16 @@ int main(int argc, char ** argv) { LOG_INF("Auto-resume: found checkpoint %s\n", ft_params.resume_from_checkpoint.c_str()); } } - + if (!ft_params.resume_from_checkpoint.empty()) { params.warmup = false; } - + // Load checkpoint LoRA adapter from directory structure (model.gguf) if (!ft_params.resume_from_checkpoint.empty()) { std::filesystem::path checkpoint_dir(ft_params.resume_from_checkpoint); std::filesystem::path model_path = checkpoint_dir / "model.gguf"; - + LOG_INF("Loading checkpoint LoRA adapter: %s\n", model_path.c_str()); common_adapter_lora_info lora_adapter; lora_adapter.path = model_path.string(); @@ -542,7 +717,7 @@ int main(int argc, char ** argv) { (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP) ? "yes" : "no", (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no", (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT) ? "yes" : "no"); - + LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n", lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank); @@ -556,7 +731,7 @@ int main(int argc, char ** argv) { constexpr float val_split = 0.05f; ggml_opt_dataset_t dataset; - + if (ft_params.assistant_loss_only) { LOG_INF("Using JSON dataset with chat template and assistant-only loss\n"); dataset = common_opt_sft_dataset_init(ctx.get(), params.prompt, llama_n_ctx(ctx.get())/2, ft_params.chat_template_path); @@ -565,29 +740,58 @@ int main(int argc, char ** argv) { LOG_INF("Using standard next-token prediction mode\n"); dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); } - + if (dataset == nullptr) { LOG_ERR("Failed to create dataset. 
Please check your input file and parameters.\n");
         return 1;
     }
+    const int64_t total_datapoints = ggml_opt_dataset_ndata(dataset);
+    const int64_t idata_split = static_cast<int64_t>(total_datapoints * (1.0f - val_split));
+    const int64_t training_batches_per_epoch = idata_split;
+
+    if (training_batches_per_epoch <= 0) {
+        LOG_ERR("Training split is empty. Adjust --val-split or dataset size.\n");
+        return 1;
+    }
+
+    lora_lr_scheduler_state lr_scheduler;
+    lr_scheduler.schedule = scheduler_type;
+    lr_scheduler.lr_init = ft_params.learning_rate;
+    lr_scheduler.lr_min = (scheduler_type == lora_lr_schedule_type::CONSTANT) ? ft_params.learning_rate : ft_params.lr_min;
+    lr_scheduler.weight_decay = ft_params.weight_decay;
+    lr_scheduler.total_steps = std::max<int64_t>(1, static_cast<int64_t>(ft_params.num_epochs) * training_batches_per_epoch);
+    lr_scheduler.current_step = 0;
+    lr_scheduler.last_lr = lora_scheduler_lr_for_step(lr_scheduler, lr_scheduler.current_step);
+
+    LOG_INF("Training split: datapoints=%lld, batches_per_epoch=%lld\n",
+            (long long) total_datapoints, (long long) training_batches_per_epoch);
+    LOG_INF("Optimizer: adamw scheduler=%s lr=%.4e wd=%.4e total_steps=%lld\n",
+            lora_lr_scheduler_type_to_cstr(lr_scheduler.schedule), lr_scheduler.lr_init,
+            lr_scheduler.weight_decay, (long long) lr_scheduler.total_steps);
+    if (lr_scheduler.schedule == lora_lr_schedule_type::COSINE) {
+        LOG_INF("Cosine scheduler: lr-min=%.4e\n", lr_scheduler.lr_min);
+    } else if (lr_scheduler.schedule == lora_lr_schedule_type::LINEAR) {
+        LOG_INF("Linear scheduler: lr-min=%.4e\n", lr_scheduler.lr_min);
+    }
+
     int start_epoch = 0;
     int64_t start_step = 0;
     checkpoint_metadata checkpoint_meta = {};
     bool checkpoint_loaded = false;
-
+
     if (!ft_params.resume_from_checkpoint.empty()) {
         if (validate_checkpoint_metadata(ft_params.resume_from_checkpoint, checkpoint_meta)) {
             start_epoch = checkpoint_meta.epoch;
             checkpoint_loaded = true;
-
+
             if (checkpoint_meta.lora_rank != ft_params.lora_rank) {
-                LOG_ERR("Checkpoint LoRA rank (%d) doesn't match current rank (%d). Use --resume-from to manually specify a compatible checkpoint.\n",
+                LOG_ERR("Checkpoint LoRA rank (%d) doesn't match current rank (%d).
Use --resume-from to manually specify a compatible checkpoint.\n", checkpoint_meta.lora_rank, ft_params.lora_rank); return 1; } if (checkpoint_meta.lora_alpha != ft_params.lora_alpha) { - LOG_ERR("Checkpoint LoRA alpha (%.3f) doesn't match current alpha (%.3f)\n", + LOG_ERR("Checkpoint LoRA alpha (%.3f) doesn't match current alpha (%.3f)\n", checkpoint_meta.lora_alpha, ft_params.lora_alpha); return 1; } @@ -595,16 +799,12 @@ int main(int argc, char ** argv) { LOG_ERR("Checkpoint target_modules doesn't match current target_modules\n"); return 1; } - + } else { LOG_ERR("Failed to load checkpoint, starting from scratch\n"); } } - struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); - optimizer_params.adamw.alpha = 1e-5f; // learning rate - optimizer_params.adamw.wd = 0.01f; - std::string optimizer_checkpoint_path; if (checkpoint_loaded && !ft_params.resume_from_checkpoint.empty()) { std::filesystem::path checkpoint_dir(ft_params.resume_from_checkpoint); @@ -615,31 +815,33 @@ int main(int argc, char ** argv) { /*n_ctx_train =*/ 0, /*param_filter =*/ llama_opt_param_filter_lora, /*param_filter_ud =*/ nullptr, - /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, - /*get_opt_pars_ud =*/ &optimizer_params, + /*get_opt_pars =*/ lora_scheduler_get_optimizer_params, + /*get_opt_pars_ud =*/ &lr_scheduler, /*optimizer_type =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW, /*checkpoint_path =*/ checkpoint_loaded ? optimizer_checkpoint_path.c_str() : nullptr, /*load_optimizer_state =*/ checkpoint_loaded, /*assistant_loss_only =*/ ft_params.assistant_loss_only, }; - + llama_opt_init(ctx.get(), model.get(), lopt_params); - + if (checkpoint_loaded) { start_step = llama_opt_get_iter(ctx.get()); } - + + lr_scheduler.current_step = std::min(start_step, lr_scheduler.total_steps); + lr_scheduler.last_lr = lora_scheduler_lr_for_step(lr_scheduler, lr_scheduler.current_step); + if (!trained_adapter) { LOG_ERR("No trained adapter available for checkpointing\n"); return 1; } - - const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); - const int64_t training_batches_per_epoch = idata_split; if (start_step > 0) { int64_t completed_epochs = start_step / training_batches_per_epoch; start_epoch = (int)completed_epochs; + LOG_INF("Resuming training from global step %lld (lr=%.4e)\n", + (long long) start_step, lr_scheduler.last_lr); } checkpoint_callback_data cb_data = { @@ -653,7 +855,8 @@ int main(int argc, char ** argv) { /*lora_rank =*/ ft_params.lora_rank, /*lora_alpha =*/ ft_params.lora_alpha, /*target_modules =*/ target_modules, - /*learning_rate =*/ optimizer_params.adamw.alpha, + /*learning_rate =*/ lr_scheduler.last_lr, + /*lr_scheduler =*/ &lr_scheduler, /*model_path =*/ params.model.path, /*dataset_path =*/ params.prompt_file, }; @@ -663,17 +866,20 @@ int main(int argc, char ** argv) { ggml_opt_result_t result_eval = ggml_opt_result_init(); for (int epoch = start_epoch; epoch < ft_params.num_epochs; ++epoch) { - LOG_INF("Starting epoch %d (step %lld)\n", epoch, (long long)cb_data.global_step); + if (cb_data.lr_scheduler) { + cb_data.learning_rate = lora_scheduler_lr_for_step(*cb_data.lr_scheduler, cb_data.lr_scheduler->current_step); + } + LOG_INF("Starting epoch %d (step %lld, lr=%.4e)\n", epoch, (long long)cb_data.global_step, cb_data.learning_rate); cb_data.current_epoch = epoch; - + int64_t resume_batch = 0; if (start_step > 0 && epoch == start_epoch) { resume_batch = start_step % training_batches_per_epoch; } - - ggml_opt_epoch_callback 
train_callback = (ft_params.checkpoint_save_steps <= 0) ?
+
+        ggml_opt_epoch_callback train_callback = (ft_params.checkpoint_save_steps <= 0) ?
             ggml_opt_epoch_callback_progress_bar : checkpoint_progress_callback;
-        ggml_opt_epoch_callback eval_callback = (ft_params.checkpoint_save_steps <= 0) ?
+        ggml_opt_epoch_callback eval_callback = (ft_params.checkpoint_save_steps <= 0) ?
             ggml_opt_epoch_callback_progress_bar : checkpoint_progress_callback;

         if (resume_batch > 0) {

From 4674584e4be97eda6f6584403fe1161de9c49610 Mon Sep 17 00:00:00 2001
From: Marcus Edel
Date: Mon, 27 Oct 2025 18:58:48 -0400
Subject: [PATCH 3/5] Add warmup-ratio parameter to match HF training.

Signed-off-by: Marcus Edel
---
 examples/training/finetune-lora.cpp | 56 +++++++++++++++++++++++++++--
 1 file changed, 54 insertions(+), 2 deletions(-)

diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp
index f23287bca20..d42137776c0 100644
--- a/examples/training/finetune-lora.cpp
+++ b/examples/training/finetune-lora.cpp
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -36,6 +37,8 @@ struct lora_lr_scheduler_state {
     int64_t total_steps = 0;
     int64_t current_step = 0;
     float last_lr = 0.0f;
+    float warmup_ratio = 0.0f;
+    int64_t warmup_steps = 0;
 };

 static bool lora_lr_scheduler_type_from_string(const std::string & name, lora_lr_schedule_type & out) {
@@ -84,6 +87,21 @@ static float lora_scheduler_lr_for_step(const lora_lr_scheduler_state & state, i
     }

     const int64_t clamped_step = std::min(std::max<int64_t>(step, 0), state.total_steps);
+    const int64_t warmup_steps = std::min(std::max<int64_t>(state.warmup_steps, 0), state.total_steps);
+
+    if (warmup_steps > 0 && clamped_step < warmup_steps) {
+        const float warmup_progress = static_cast<float>(clamped_step) / static_cast<float>(warmup_steps);
+        const float lr = state.lr_init * warmup_progress;
+        return std::max(lr, 0.0f);
+    }
+
+    const int64_t adjusted_step = clamped_step - warmup_steps;
+    int64_t remaining_steps = state.total_steps - warmup_steps;
+    if (remaining_steps <= 0) {
+        remaining_steps = 1;
+    }
+
+    const float progress = std::min(static_cast<float>(adjusted_step) / static_cast<float>(remaining_steps), 1.0f);
     float lr = state.lr_init;

     switch (state.schedule) {
@@ -92,13 +110,11 @@ static float lora_scheduler_lr_for_step(const lora_lr_scheduler_state & state, i
             break;
         case lora_lr_schedule_type::COSINE: {
             constexpr float kPi = 3.14159265358979323846f;
-            const float progress = static_cast<float>(clamped_step) / static_cast<float>(state.total_steps);
             const float cosine = 0.5f * (1.0f + std::cos(progress * kPi));
             lr = state.lr_min + (state.lr_init - state.lr_min) * cosine;
             break;
         }
         case lora_lr_schedule_type::LINEAR: {
-            const float progress = static_cast<float>(clamped_step) / static_cast<float>(state.total_steps);
             lr = state.lr_init + (state.lr_min - state.lr_init) * progress;
             break;
         }
@@ -253,6 +269,8 @@ static void print_lora_usage() {
     printf("  --weight-decay F              AdamW weight decay (default: 1e-2)\n");
     printf("  --lr-scheduler TYPE           Learning rate scheduler: constant, cosine, linear (default: constant)\n");
     printf("  --lr-min F                    Minimum LR for cosine/linear schedulers (default: 0)\n");
+    printf("  --warmup-ratio F              Fraction of total steps for LR warmup (default: 0.0)\n");
+    printf("  --warmup-steps N              Explicit warmup steps (overrides warmup-ratio)\n");
     printf("\nCheckpointing Options:\n");
     printf("  --checkpoint-save-steps N     Save checkpoint every N training steps (default: 100)\n");
     printf("  --checkpoint-save-dir PATH    Directory for checkpoints (default: ./checkpoints)\n");
@@ -472,6 +490,10 @@ struct finetune_params {
     float lr_min = 0.0f;
     float weight_decay = 0.01f;
     std::string lr_scheduler = "constant";
+    float warmup_ratio = 0.0f;
+    int64_t warmup_steps = 0;
+    bool warmup_ratio_set = false;
+    bool warmup_steps_set = false;

     int32_t checkpoint_save_steps = 100;
     std::string checkpoint_save_dir = "./checkpoints";
@@ -541,6 +563,16 @@ static bool parse_finetune_args(int& argc, char** argv, finetune_params& ft_para
             ft_params.lr_min = std::atof(argv[i + 1]);
             remove_arg_pair(i);
             i--;
+        } else if (strcmp(argv[i], "--warmup-ratio") == 0 && i + 1 < argc) {
+            ft_params.warmup_ratio = std::atof(argv[i + 1]);
+            ft_params.warmup_ratio_set = true;
+            remove_arg_pair(i);
+            i--;
+        } else if (strcmp(argv[i], "--warmup-steps") == 0 && i + 1 < argc) {
+            ft_params.warmup_steps = std::atoll(argv[i + 1]);
+            ft_params.warmup_steps_set = true;
+            remove_arg_pair(i);
+            i--;
         } else if (strcmp(argv[i], "--checkpoint-save-steps") == 0 && i + 1 < argc) {
             ft_params.checkpoint_save_steps = std::atoi(argv[i + 1]);
             remove_arg_pair(i);
@@ -761,6 +793,20 @@ int main(int argc, char ** argv) {
     lr_scheduler.lr_min = (scheduler_type == lora_lr_schedule_type::CONSTANT) ? ft_params.learning_rate : ft_params.lr_min;
     lr_scheduler.weight_decay = ft_params.weight_decay;
     lr_scheduler.total_steps = std::max<int64_t>(1, static_cast<int64_t>(ft_params.num_epochs) * training_batches_per_epoch);
+    if (ft_params.warmup_steps_set) {
+        lr_scheduler.warmup_steps = std::min(ft_params.warmup_steps, lr_scheduler.total_steps);
+    } else if (ft_params.warmup_ratio_set) {
+        const double warmup_from_ratio = static_cast<double>(lr_scheduler.total_steps) * static_cast<double>(ft_params.warmup_ratio);
+        lr_scheduler.warmup_steps = std::min(static_cast<int64_t>(warmup_from_ratio), lr_scheduler.total_steps);
+    } else {
+        lr_scheduler.warmup_steps = 0;
+    }
+    lr_scheduler.warmup_steps = std::max<int64_t>(lr_scheduler.warmup_steps, 0);
+    if (lr_scheduler.total_steps > 0) {
+        lr_scheduler.warmup_ratio = static_cast<float>(lr_scheduler.warmup_steps) / static_cast<float>(lr_scheduler.total_steps);
+    } else {
+        lr_scheduler.warmup_ratio = 0.0f;
+    }
     lr_scheduler.current_step = 0;
     lr_scheduler.last_lr = lora_scheduler_lr_for_step(lr_scheduler, lr_scheduler.current_step);

@@ -774,6 +820,12 @@ int main(int argc, char ** argv) {
     } else if (lr_scheduler.schedule == lora_lr_schedule_type::LINEAR) {
         LOG_INF("Linear scheduler: lr-min=%.4e\n", lr_scheduler.lr_min);
     }
+    if (lr_scheduler.warmup_steps > 0) {
+        LOG_INF("Warmup: steps=%lld ratio=%.4f\n", (long long) lr_scheduler.warmup_steps, lr_scheduler.warmup_ratio);
+    } else if (ft_params.warmup_ratio_set) {
+        LOG_WRN("Warmup ratio %.4f produced 0 warmup steps (total_steps=%lld); no warmup applied\n",
+                ft_params.warmup_ratio, (long long) lr_scheduler.total_steps);
+    }

     int start_epoch = 0;
     int64_t start_step = 0;

From 082e7a0f8b0475c8849a261c650f0a0025ffda29 Mon Sep 17 00:00:00 2001
From: makaveli10
Date: Tue, 28 Oct 2025 15:09:44 +0100
Subject: [PATCH 4/5] lora: Fix lr assertion on step 0

---
 examples/training/finetune-lora.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp
index d42137776c0..f8ea824504d 100644
--- a/examples/training/finetune-lora.cpp
+++ b/examples/training/finetune-lora.cpp
@@ -131,11 +131,11 @@ static struct ggml_opt_optimizer_params lora_scheduler_get_optimizer_params(void
         return params;
     }

-    const float lr =
lora_scheduler_lr_for_step(*scheduler, scheduler->current_step+1); scheduler->last_lr = lr; params.adamw.alpha = lr; - params.adamw.wd = scheduler->weight_decay; + params.adamw.wd = scheduler->weight_decay; params.sgd.alpha = lr; params.sgd.wd = scheduler->weight_decay; @@ -444,7 +444,7 @@ static void checkpoint_progress_callback( } if (cb_data->lr_scheduler) { - cb_data->learning_rate = lora_scheduler_lr_for_step(*cb_data->lr_scheduler, cb_data->lr_scheduler->current_step); + cb_data->learning_rate = lora_scheduler_lr_for_step(*cb_data->lr_scheduler, cb_data->lr_scheduler->current_step+1); } if (cb_data->checkpoint_save_steps <= 0) { @@ -919,7 +919,7 @@ int main(int argc, char ** argv) { for (int epoch = start_epoch; epoch < ft_params.num_epochs; ++epoch) { if (cb_data.lr_scheduler) { - cb_data.learning_rate = lora_scheduler_lr_for_step(*cb_data.lr_scheduler, cb_data.lr_scheduler->current_step); + cb_data.learning_rate = lora_scheduler_lr_for_step(*cb_data.lr_scheduler, cb_data.lr_scheduler->current_step+1); } LOG_INF("Starting epoch %d (step %lld, lr=%.4e)\n", epoch, (long long)cb_data.global_step, cb_data.learning_rate); cb_data.current_epoch = epoch; From 661890cdd01ab48c8fbe81da497d688b0a1ec63f Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Sat, 1 Nov 2025 07:28:07 -0400 Subject: [PATCH 5/5] lora: Fix training start from step 2 --- src/llama-context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 24637c3d16e..10f9a665e16 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2290,7 +2290,7 @@ void llama_context::opt_epoch( std::vector labels_sparse(n_ctx); std::vector masks_sparse(n_ctx); - int64_t idata = (resume_from_batch >= 0) ? resume_from_batch + 1 : 0; + int64_t idata = (resume_from_batch > 0) ? resume_from_batch + 1 : 0; int64_t t_loop_start = ggml_time_us(); int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
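
For reference, this is roughly the per-row computation the cross_entropy_loss_masked_back shader above performs, written as a plain CPU sketch: softmax(logits) minus the one-hot labels, scaled by the upstream gradient divided by the number of mask-enabled rows, with masked-out rows receiving a zero gradient. Function name, signature, and buffer layout here are illustrative only, not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static void cross_entropy_masked_back_ref(
        const std::vector<float> & logits,   // [nrows * nclasses]
        const std::vector<float> & labels,   // one-hot, same shape as logits
        const std::vector<float> & mask,     // same shape; element 0 of each row is 1.0 (keep) or 0.0 (skip)
        float                      upstream_grad,
        int64_t                    nrows,
        int64_t                    nclasses,
        std::vector<float>       & grad) {   // output, same shape as logits
    // count the rows that participate in the loss
    int64_t n_valid = 0;
    for (int64_t r = 0; r < nrows; ++r) {
        n_valid += mask[r*nclasses] > 0.5f ? 1 : 0;
    }
    const float scale = n_valid > 0 ? upstream_grad / (float) n_valid : 0.0f;

    for (int64_t r = 0; r < nrows; ++r) {
        if (mask[r*nclasses] <= 0.5f) {
            // masked-out row: no contribution to the gradient
            for (int64_t c = 0; c < nclasses; ++c) grad[r*nclasses + c] = 0.0f;
            continue;
        }
        // numerically stable softmax over the row
        float row_max = -INFINITY;
        for (int64_t c = 0; c < nclasses; ++c) row_max = std::max(row_max, logits[r*nclasses + c]);
        float row_sum = 0.0f;
        for (int64_t c = 0; c < nclasses; ++c) row_sum += std::exp(logits[r*nclasses + c] - row_max);
        for (int64_t c = 0; c < nclasses; ++c) {
            const float sm = std::exp(logits[r*nclasses + c] - row_max) / row_sum;
            grad[r*nclasses + c] = (sm - labels[r*nclasses + c]) * scale;
        }
    }
}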
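
Likewise, a compact, self-contained sketch of the learning-rate schedule the finetune-lora changes implement: linear warmup from 0 to the base LR over warmup_steps, then cosine or linear decay toward lr_min over the remaining steps (the constant schedule simply keeps returning the base LR after warmup). The function and parameter names are assumptions for illustration, not identifiers from the patch.

#include <algorithm>
#include <cmath>
#include <cstdint>

// step -> learning rate under the warmup + decay schedule described above
static float lr_at_step(int64_t step, int64_t total_steps, int64_t warmup_steps,
                        float lr_init, float lr_min, bool cosine) {
    step = std::min(std::max<int64_t>(step, 0), total_steps);
    if (warmup_steps > 0 && step < warmup_steps) {
        // warmup: ramp linearly from 0 up to lr_init
        return lr_init * (float) step / (float) warmup_steps;
    }
    const int64_t remaining = std::max<int64_t>(total_steps - warmup_steps, 1);
    const float progress = std::min((float) (step - warmup_steps) / (float) remaining, 1.0f);
    if (cosine) {
        // half-cosine from lr_init down to lr_min
        return lr_min + (lr_init - lr_min) * 0.5f * (1.0f + std::cos(progress * 3.14159265358979323846f));
    }
    // linear interpolation from lr_init down to lr_min
    return lr_init + (lr_min - lr_init) * progress;
}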