289 changes: 289 additions & 0 deletions common/common.cpp
@@ -8,6 +8,8 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include <nlohmann/json.hpp>

#include <algorithm>
#include <cinttypes>
@@ -1616,3 +1618,290 @@ float lr_opt::get_lr(float epoch) const {
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}

ggml_opt_dataset_t common_opt_sft_dataset_init(
struct llama_context * ctx,
const std::string & json_content,
int64_t stride,
const std::string & chat_template_path) {
using json = nlohmann::json;

const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx));
common_chat_templates_ptr chat_templates;
std::string chat_template_source;
if (!chat_template_path.empty()) {
std::ifstream tmpl_file(chat_template_path);
if (!tmpl_file.is_open()) {
LOG_WRN("Warning: Failed to open chat template file: %s\n", chat_template_path.c_str());
} else {
chat_template_source.assign(std::istreambuf_iterator<char>(tmpl_file), std::istreambuf_iterator<char>());
tmpl_file.close();
try {
chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
} catch (const std::exception & e) {
LOG_WRN("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
}
}
}

std::vector<json> conversations;
std::istringstream content_stream(json_content);

std::string line;
while (std::getline(content_stream, line)) {
if (line.empty() || line[0] == '#') continue;
try {
json conv = json::parse(line);
if (conv.contains("messages") && conv["messages"].is_array()) {
conversations.push_back(conv);
}
} catch (const json::exception & e) {
LOG_DBG("Warning: Failed to parse JSON line: %s\n", e.what());
}
}

if (conversations.empty()) {
LOG_ERR("Error: No valid conversations found\n");
return nullptr;
}
LOG_INF("Loaded %zu conversations\n", conversations.size());

const int64_t ne_datapoint = llama_n_ctx(ctx);
if (stride <= 0) stride = ne_datapoint;
if (stride > ne_datapoint) stride = ne_datapoint;

std::vector<std::vector<llama_token>> all_tokenized_data;
std::vector<std::vector<int32_t>> all_assistant_masks;

auto token_count_prefix = [&](const std::string & render, size_t char_count) -> size_t {
std::string prefix = render.substr(0, char_count);
auto t = common_tokenize(ctx, prefix, /*add_special=*/false, /*parse_special=*/true);
return t.size();
};

const std::string START_TAG = "<|im_start|>";
const std::string START_SYS = "<|im_start|>system\n";
const std::string START_USR = "<|im_start|>user\n";
const std::string START_AST = "<|im_start|>assistant\n";
const std::string END_TAG = "<|im_end|>";
const std::string NL = "\n";

for (size_t i = 0; i < conversations.size(); ++i) {
const auto & messages = conversations[i]["messages"];
if (!messages.is_array() || messages.empty()) continue;

std::string render;

if (chat_templates) {
std::vector<common_chat_msg> chat_msgs;
chat_msgs.reserve(messages.size());
for (const auto & msg : messages) {
if (!msg.contains("role") || !msg.contains("content")) {
continue;
}
common_chat_msg chat_msg;
chat_msg.role = msg["role"].get<std::string>();
chat_msg.content = msg["content"].get<std::string>();
chat_msgs.push_back(std::move(chat_msg));
}

if (!chat_msgs.empty()) {
common_chat_templates_inputs inputs;
inputs.messages = std::move(chat_msgs);
inputs.add_generation_prompt = false;
inputs.use_jinja = true;
try {
render = common_chat_templates_apply(chat_templates.get(), inputs).prompt;

size_t last_im_end = render.rfind("<|im_end|>");
if (last_im_end != std::string::npos) {
size_t end_pos = last_im_end + 10; // length of "<|im_end|>"
// Remove any trailing whitespace/newlines after the final <|im_end|>
while (end_pos < render.size() && (render[end_pos] == '\n' || render[end_pos] == '\r' || render[end_pos] == ' ')) {
end_pos++;
}
if (end_pos < render.size()) {
render = render.substr(0, last_im_end + 10); // keep only up to and including the final <|im_end|>
}
}
} catch (const std::exception & e) {
LOG_WRN("Warning: chat template rendering failed for conversation %zu: %s. Falling back to default ChatML rendering.\n",
i, e.what());
}
}
}

if (render.empty()) {
render.reserve(4096);
for (const auto & msg : messages) {
if (!msg.contains("role") || !msg.contains("content")) continue;
const std::string role = msg["role"].get<std::string>();
const std::string content = msg["content"].get<std::string>();

if (role == "system") {
render += START_SYS; render += content; render += END_TAG + NL;
} else if (role == "user") {
render += START_USR; render += content; render += END_TAG + NL;
} else if (role == "assistant") {
render += START_AST; render += content; render += END_TAG + NL;
}
}
}

if (render.empty()) {
continue;
}

struct Span { size_t lo, hi; };
std::vector<Span> assistant_spans;

{
size_t from = 0;
while (true) {
size_t open = render.find(START_AST, from);
if (open == std::string::npos) break;

// Include the role token ("assistant") and everything through the closing <|im_end|> tag
size_t lo = open + START_TAG.size();
if (lo > render.size()) {
lo = render.size();
}

size_t close = render.find(END_TAG, open + START_AST.size());
if (close == std::string::npos) {
assistant_spans.push_back({lo, render.size()});
break;
}

size_t hi = close + END_TAG.size();
if (hi <= lo) {
lo = open;
hi = close + END_TAG.size();
}

assistant_spans.push_back({lo, std::min(hi, render.size())});

size_t next_from = hi;
from = next_from;
}
}

if (assistant_spans.empty()) {
LOG_WRN("Conversation %zu has no assistant spans\n", i);
continue;
}

auto tokens_full = common_tokenize(ctx, render, /*add_special=*/false, /*parse_special=*/true);
if (tokens_full.empty()) continue;

std::vector<int32_t> assistant_mask(tokens_full.size(), 0);
size_t assistant_token_count = 0;

for (const auto & sp : assistant_spans) {
size_t t_lo = token_count_prefix(render, sp.lo);
size_t t_hi = token_count_prefix(render, sp.hi);
if (t_lo > tokens_full.size()) t_lo = tokens_full.size();
if (t_hi > tokens_full.size()) t_hi = tokens_full.size();


for (size_t t = t_lo; t < t_hi; ++t) {
assistant_mask[t] = 1;
++assistant_token_count;
}
}

if (assistant_token_count == 0) {
LOG_WRN("Warning: Conversation %zu has zero assistant tokens after masking\n", i);
continue;
}

all_tokenized_data.push_back(tokens_full);
all_assistant_masks.push_back(assistant_mask);
}

if (all_tokenized_data.empty()) {
LOG_ERR("ERROR: No valid training samples generated after processing %zu conversations\n", conversations.size());
return nullptr;
}

std::vector<std::vector<llama_token>> final_samples;
std::vector<std::vector<int32_t>> final_masks;

llama_token pad_token = llama_vocab_pad(vocab);
if (pad_token == LLAMA_TOKEN_NULL) {
pad_token = llama_vocab_eos(vocab);
}

for (size_t i = 0; i < all_tokenized_data.size(); ++i) {
const auto& conv_tokens = all_tokenized_data[i];
const auto& conv_mask = all_assistant_masks[i];

if ((int64_t)conv_tokens.size() > ne_datapoint) {
LOG_WRN("Skipping conversation %zu: too long (%zu tokens > %lld)\n", i, conv_tokens.size(), (long long)ne_datapoint);
continue;
}

size_t conv_assistant_tokens = 0;
for (int32_t mask_val : conv_mask) {
if (mask_val == 1) conv_assistant_tokens++;
}

if (conv_assistant_tokens == 0) {
LOG_WRN("Skipping conversation %zu: no assistant tokens\n", i);
continue;
}

std::vector<llama_token> sample_tokens = conv_tokens;
std::vector<int32_t> sample_mask = conv_mask;

sample_tokens.resize(ne_datapoint, pad_token);
sample_mask.resize(ne_datapoint, 0); // Padding tokens are not trained on

final_samples.push_back(sample_tokens);
final_masks.push_back(sample_mask);
}

all_tokenized_data = std::move(final_samples);
all_assistant_masks = std::move(final_masks);

const int64_t ndata = all_tokenized_data.size();

ggml_opt_dataset_t result = ggml_opt_dataset_init_with_masks(
GGML_TYPE_I32, GGML_TYPE_I32, GGML_TYPE_I32,
/*ne_datapoint=*/ne_datapoint, /*ne_label=*/ne_datapoint, /*ne_mask=*/ne_datapoint,
/*ndata=*/ndata, /*ndata_shard=*/1);

if (result == nullptr) {
return nullptr;
}

int32_t * data = (int32_t *) ggml_opt_dataset_data(result)->data;
int32_t * labels = (int32_t *) ggml_opt_dataset_labels(result)->data;
int32_t * masks = (int32_t *) ggml_opt_dataset_masks(result)->data;

for (int64_t idata = 0; idata < ndata; ++idata) {
const auto & sample_tokens = all_tokenized_data[idata];
const auto & sample_mask = all_assistant_masks[idata];

// inputs
for (int64_t i = 0; i < ne_datapoint; ++i) {
data[idata * ne_datapoint + i] = sample_tokens[i];
}

// labels: Set actual next tokens for ALL positions (masked cross-entropy needs real tokens)
for (int64_t i = 0; i < ne_datapoint - 1; ++i) {
// Always set the actual next token - masking is handled separately
labels[idata * ne_datapoint + i] = sample_tokens[i + 1];
}
labels[idata * ne_datapoint + (ne_datapoint - 1)] = sample_tokens[ne_datapoint - 1]; // last token predicts itself (will be masked)

// masks: indicate which preds should be trained on (shifted by 1 from sample_mask)
// Since we predict token[i+1] from token[i], we train when token[i+1] is assistant
for (int64_t i = 0; i < ne_datapoint - 1; ++i) {
masks[idata * ne_datapoint + i] = (i + 1 < ne_datapoint && sample_mask[i + 1] == 1) ? 1 : 0;
}
masks[idata * ne_datapoint + (ne_datapoint - 1)] = 0;
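// Worked example with ne_datapoint = 4, tokens = [A, B, C, D], sample_mask = [0, 1, 1, 0]:
//   labels = [B, C, D, D], masks = [1, 1, 0, 0] -> loss is computed only where the *next* token is an assistant token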
}

return result;
}
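
For context, a minimal sketch (not part of this diff) of how a caller might use the new helper, assuming `common.h` is available and the JSONL dataset is read into a string first; the file path, helper name, and error handling here are illustrative:

```cpp
#include <fstream>
#include <sstream>
#include <string>

#include "common.h"

// Hypothetical caller: load a JSONL dataset from disk and build an SFT dataset
// with assistant-only masks via common_opt_sft_dataset_init.
static ggml_opt_dataset_t load_sft_dataset(struct llama_context * ctx, const std::string & path) {
    std::ifstream file(path); // e.g. "dataset.jsonl" (illustrative path)
    if (!file.is_open()) {
        return nullptr;
    }
    std::ostringstream buf;
    buf << file.rdbuf(); // read the whole JSONL file into memory
    // stride <= 0 lets the helper default to the context size; the chat template path is optional
    return common_opt_sft_dataset_init(ctx, buf.str(), /*stride=*/0, /*chat_template_path=*/"");
}
```
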
5 changes: 5 additions & 0 deletions common/common.h
@@ -745,3 +745,8 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std

// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
ggml_opt_dataset_t common_opt_sft_dataset_init(
struct llama_context * ctx,
const std::string & json_content,
int64_t stride,
const std::string & chat_template_path = "");
15 changes: 13 additions & 2 deletions examples/training/README.md
@@ -27,11 +27,11 @@ the base model frozen, making it memory-efficient.

```sh
# Create new LoRA adapter with default settings (rank=8, alpha=16, attention modules)
./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512
./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 -fa off

# Custom LoRA parameters (creates a new LoRA adapter and trains it from scratch)
./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \
--lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o"
--lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" -fa off

# Fine-tune existing LoRA adapter
./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \
@@ -44,8 +44,17 @@ the base model frozen, making it memory-efficient.
# Resume training from checkpoint
./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \
--resume-from "./lora_checkpoints/checkpoint_step_00000150/" \
--output-adapter improved_adapter.gguf -fa off

# Supervised fine-tuning (SFT) with assistant-only loss
./build/bin/llama-finetune-lora -m model.gguf -f dataset.jsonl -ngl 999 -c 512 -b 512 -ub 512 \
--lora-modules "attn_q,attn_k,attn_v,attn_o" --assistant-loss-only -fa off
```

### SFT (Instruction Fine-Tuning) with Assistant-Only Loss
- Masks system and user tokens so the loss is computed only on assistant tokens
- Requires the dataset in JSONL format, Hugging Face style, with a `messages` array whose entries each have `role` and `content` fields (see the example below)
- Optionally accepts a Jinja chat template via `--chat-template chat-ml-template.jinja`; without one, a default ChatML rendering is used
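
A minimal example of the expected JSONL input, one conversation per line (field names follow the parser in `common_opt_sft_dataset_init`; the content itself is illustrative):

```json
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}]}
```

Empty lines and lines starting with `#` are skipped; only the assistant turns contribute to the loss.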

### Parameters

@@ -60,6 +69,8 @@ the base model frozen, making it memory-efficient.
- Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all`
- Default: `attn_q,attn_k,attn_v,attn_o` (attention modules)
- `--output-adapter PATH` - Output adapter filename (default: auto-generated)
- `--assistant-loss-only` - Compute the loss only on assistant tokens
- `--chat-template PATH` - Jinja chat template used to render conversations for assistant-only training (defaults to ChatML formatting if omitted)

#### Checkpointing
- `--checkpoint-save-steps N` - Save checkpoint every N training steps (default: 100)