
Commit 0fd18de

lora: Update readme; add architecture overview
1 parent 9271a97 commit 0fd18de

File tree

- README.md
- examples/training/README.md

2 files changed: 79 additions, 4 deletions

README.md

Lines changed: 14 additions & 0 deletions
@@ -516,6 +516,20 @@ To learn more about model quantization, [read this documentation](tools/quantize
</details>

## LoRA Fine-Tuning

llama.cpp includes native [LoRA](https://arxiv.org/abs/2106.09685) (Low-Rank Adaptation) fine-tuning across the CPU, Vulkan, Metal, and CUDA backends.

LoRA fine-tuning represents each weight update as the product of two much smaller matrices obtained through low-rank decomposition, while the base model stays frozen. Only these small matrices are trained, which keeps the number of trainable parameters low. This makes training possible on devices with very limited memory, including phones and integrated GPUs (see the parameter-count sketch after the list below). Key capabilities include:
- Train LoRA adapters on any GPU (NVIDIA, AMD, Intel, Apple, Mali, Adreno)
- Full support for FP32/FP16/Q8/Q4 training paths
- Instruction tuning via assistant-only masked loss
- Checkpointing and resumable training
- Merging of LoRA adapters back into the base model to produce a standalone `model.gguf`
- Compatible with Qwen3, Gemma, LLaMA, TinyLlama, and other GGUF models
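As a rough illustration of the memory savings (the sizes below are made-up example values, not from llama.cpp): for a `d×k` weight adapted at rank `r`, only `r·(d+k)` parameters are trained instead of `d·k`.

```cpp
#include <cstdio>

// Made-up example sizes: a 4096x4096 projection adapted at rank 16.
int main() {
    const long long d = 4096, k = 4096, r = 16;
    const long long full = d * k;        // parameters if W itself were trained
    const long long lora = r * (d + k);  // parameters in the A/B pair
    std::printf("full: %lld  lora: %lld  (%.2f%% of full)\n",
                full, lora, 100.0 * (double) lora / (double) full);
}
```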

The [Finetuning Guide](examples/training/README.md) has more details.
## Contributing

examples/training/README.md

Lines changed: 65 additions & 4 deletions
@@ -69,8 +69,6 @@ the base model frozen, making it memory-efficient.
- Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all`
- Default: `attn_q,attn_k,attn_v,attn_o` (attention modules)
- `--output-adapter PATH` - Output adapter filename (default: auto-generated)

Removed here (both flags move to the training-flags list below):

- `--assistant-loss-only` - Trains only on assistant tokens
- `--chat-template` - Jinja chat template for chat ML formatting to train on assistant tokens only

#### Checkpointing

- `--checkpoint-save-steps N` - Save checkpoint every N training steps (default: 100)
@@ -83,6 +81,14 @@ the base model frozen, making it memory-efficient.
- `-f FILE` - Training dataset
- `-ngl N` - GPU layers (use 999 for full GPU training)
- `-c N` - Context length (512 recommended for mobile)
- `--assistant-loss-only` - Train only on assistant tokens
- `--chat-template` - Jinja chat template used to format the dataset as chat markup so that only assistant tokens are trained on
- `--learning-rate` - AdamW learning rate (default: 1e-5)
- `--weight-decay` - AdamW weight decay (default: 1e-2)
- `--lr-scheduler` - Learning rate scheduler: constant, cosine, linear (default: constant)
- `--lr-min` - Minimum LR for cosine/linear schedulers (default: 0)
- `--warmup-ratio` - Fraction of total steps for LR warmup (default: 0.0)
- `--warmup-steps` - Explicit warmup steps (overrides warmup-ratio)
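These flags compose in the usual warmup-then-decay way. The sketch below shows the conventional semantics; the function and variable names are illustrative, not the llama.cpp implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <string>

// Illustrative semantics of the scheduler flags above: linear warmup to the
// base learning rate, then constant, cosine, or linear decay toward lr_min.
double lr_at_step(long step, long total_steps, long warmup_steps,
                  double lr, double lr_min, const std::string & scheduler) {
    if (warmup_steps > 0 && step < warmup_steps) {
        return lr * double(step) / double(warmup_steps); // warmup phase
    }
    // progress through the post-warmup portion of training, in [0, 1]
    const double t = double(step - warmup_steps) /
                     double(std::max(1L, total_steps - warmup_steps));
    const double pi = 3.14159265358979323846;
    if (scheduler == "cosine") {
        return lr_min + 0.5 * (lr - lr_min) * (1.0 + std::cos(pi * t));
    }
    if (scheduler == "linear") {
        return lr + (lr_min - lr) * t;
    }
    return lr; // "constant"
}
```
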
### Using Trained Adapters
@@ -96,8 +102,6 @@ After training, you'll get a small adapter file. Use it with the original base m
### Checkpointing

LoRA fine-tuning supports automatic checkpointing to save and resume training progress:

Removed here: the `#### Features` subheading.
- **Automatic saving**: Model and optimizer state saved every N training steps
- **Complete state**: Includes LoRA weights, optimizer momentum, and training metadata
- **Resume capability**: Continue training from exact step with full optimizer state
@@ -109,6 +113,63 @@ Each checkpoint directory contains:
- `optimizer.gguf` - Optimizer state (momentum, variance, iteration)
- `metadata.json` - Training parameters and step information
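All of these files are needed to resume. A small sanity check one might write before resuming; the helper below is hypothetical, not part of llama.cpp:

```cpp
#include <filesystem>

// Hypothetical helper (not part of llama.cpp): confirm a checkpoint
// directory holds everything listed above before resuming from it.
bool checkpoint_is_complete(const std::filesystem::path & dir) {
    namespace fs = std::filesystem;
    return fs::exists(dir / "model.gguf")       // LoRA adapter weights
        && fs::exists(dir / "optimizer.gguf")   // momentum, variance, iteration
        && fs::exists(dir / "metadata.json");   // training parameters and step
}
```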

### Architecture Overview

This section explains how LoRA fine-tuning is implemented in llama.cpp:
**LoRA Adapter Management (`src/llama-lora-training.cpp`):**
This file manages the complete lifecycle of LoRA adapters:
1. **Adapter Creation (`llama_lora_create_adapter()`):**
   - Iterates through all model tensors to find target modules
   - Creates a low-rank matrix pair (A, B) for each selected module
   - Tensor naming: `blk.{layer}.{module}.lora_a` and `blk.{layer}.{module}.lora_b`
   - Dimensions: `B ∈ R^(d×r)` and `A ∈ R^(r×k)`, where `r` is the rank, so that `ΔW = BA ∈ R^(d×k)` matches the adapted weight
2. **Weight Initialization (`llama_lora_init_tensor_weights()`):**
   - Matrix A: Initialized from a Gaussian distribution N(0, init_std)
   - Matrix B: Initialized to zeros
   - This ensures ΔW = BA starts at zero (no adaptation initially); a sketch of this scheme follows the list
   - Supports both CPU and GPU tensors via `ggml_backend_tensor_set()`
3. **Buffer Allocation (`llama_lora_allocate_buffers()`):**
   - Auto-detects the backend from the base model (CPU/CUDA/Vulkan)
   - Allocates LoRA tensors on the same device as the model layers
   - Uses `ggml_backend_alloc_ctx_tensors_from_buft()` for optimal placement
4. **Module Selection:**
   - Scans tensor names for the attention patterns `attn_q`, `attn_k`, `attn_v`, `attn_output`
   - FFN modules: `ffn_gate`, `ffn_up`, `ffn_down`
   - Controlled by the `target_modules` bitmask (lines 194-211)
5. **Optimizer Integration (`llama_opt_param_filter_lora()`):**
   - Filter function for `ggml-opt` to identify trainable parameters
   - Returns `true` only for tensors with a `.lora_a` or `.lora_b` suffix
   - Ensures base model weights are excluded from gradient computation
6. **Checkpointing (`llama_lora_save_checkpoint()`):**
   - Creates the checkpoint directory structure
   - Saves `model.gguf` (LoRA weights via `llama_lora_save_adapter()`)
   - Saves `optimizer.gguf` (optimizer state via `ctx->opt_save_state()`)
   - Both files are required for resuming training
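For steps 1 and 2, here is a condensed sketch of what adapter creation and initialization can look like with the public ggml API. The helper name, `rank` argument, and `init_std` value are assumptions for the example; the real logic lives in `llama_lora_create_adapter()` and `llama_lora_init_tensor_weights()`:

```cpp
#include <random>
#include <vector>

#include "ggml.h"
#include "ggml-backend.h"

// Illustrative sketch of adapter creation + init for one target weight `w`.
static void create_lora_pair(struct ggml_context * ctx, const struct ggml_tensor * w,
                             int64_t rank, const char * base_name) {
    // lora_a maps the input dimension down to the rank, lora_b maps it back
    // up, so the update B*A has the same shape as the frozen weight w.
    struct ggml_tensor * lora_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], rank);
    struct ggml_tensor * lora_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, rank, w->ne[1]);

    ggml_format_name(lora_a, "%s.lora_a", base_name); // e.g. blk.0.attn_q.lora_a
    ggml_format_name(lora_b, "%s.lora_b", base_name);

    // --- after the tensors have been allocated to a backend buffer (step 3) ---

    // A: Gaussian N(0, init_std); B: zeros, so ΔW = BA starts at zero.
    std::mt19937 rng(1234);
    std::normal_distribution<float> dist(0.0f, /*init_std=*/0.02f);
    std::vector<float> a_data(ggml_nelements(lora_a));
    for (float & v : a_data) { v = dist(rng); }
    std::vector<float> b_data(ggml_nelements(lora_b), 0.0f);

    // works for both CPU and GPU tensors
    ggml_backend_tensor_set(lora_a, a_data.data(), 0, ggml_nbytes(lora_a));
    ggml_backend_tensor_set(lora_b, b_data.data(), 0, ggml_nbytes(lora_b));
}
```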
**Forward Pass (`ggml-opt.cpp:ggml_opt_forward()`):**
1. The input batch flows through the base model with the LoRA injections applied
2. Only the LoRA parameters are marked trainable, via `llama_opt_param_filter_lora`, so the loss is differentiated only with respect to them (a sketch of such a filter follows this list)
3. For instruction tuning with `--assistant-loss-only`, loss masking is applied to system/user tokens
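A minimal sketch of such a parameter filter, following the tensor-naming convention above; the body is illustrative, not the exact `llama_opt_param_filter_lora()` source:

```cpp
#include <string>

#include "ggml.h"

// Illustrative version of the filter described above: only tensors whose
// names end in .lora_a / .lora_b are reported as trainable; everything
// else (the base model) stays frozen.
static bool param_filter_lora(const struct ggml_tensor * tensor, void * /*userdata*/) {
    const std::string name = ggml_get_name(tensor);
    const auto ends_with = [&](const std::string & suffix) {
        return name.size() >= suffix.size() &&
               name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
    };
    return ends_with(".lora_a") || ends_with(".lora_b");
}
```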
**Backward Pass (`ggml-opt.cpp:ggml_opt_backward()`):**
1. Gradients are computed only for the LoRA adapters (matrices A and B)
2. Base model weights are excluded from gradient computation
3. Memory efficient: only gradients for the low-rank matrices are stored
**Optimizer State (`ggml-opt.cpp`):**
- Uses the AdamW optimizer by default
- Maintains a first moment (momentum) and second moment (variance) for each LoRA parameter
- State tensors: `opt_state_m` and `opt_state_v` for each adapter matrix
- The checkpoint format includes the full optimizer state for seamless resumption
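For reference, `opt_state_m` and `opt_state_v` correspond to the two moments of the standard AdamW update. A minimal scalar sketch; the defaults below are the usual AdamW constants plus the weight decay from the flag list, not values read from the source:

```cpp
#include <cmath>

// Minimal scalar AdamW step: m and v are the per-parameter state that the
// opt_state_m / opt_state_v tensors hold for each adapter matrix.
void adamw_step(float & w, float grad, float & m, float & v, int t,
                float lr, float beta1 = 0.9f, float beta2 = 0.999f,
                float eps = 1e-8f, float wd = 1e-2f) {
    m = beta1 * m + (1.0f - beta1) * grad;        // first moment (momentum)
    v = beta2 * v + (1.0f - beta2) * grad * grad; // second moment (variance)
    const float m_hat = m / (1.0f - std::pow(beta1, (float) t)); // bias correction
    const float v_hat = v / (1.0f - std::pow(beta2, (float) t));
    w -= lr * (m_hat / (std::sqrt(v_hat) + eps) + wd * w); // decoupled weight decay
}
```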
### Troubleshooting

- **Out of memory**: Reduce context length (`-c 256`), lower rank, or use fewer target modules
