
Commit a1c944a

HaohanTsao authored and yeonjoon-jung01 committed
DOCS: update README for GraLoRA finetuning with correct SFTTrainer integration
1 parent 351877f commit a1c944a

File tree: 1 file changed (+17, -14 lines)


examples/gralora_finetuning/README.md

@@ -13,28 +13,34 @@ With respect to your standard PEFT training procedure with LoRA, simply swap you
 
 ```python
 import torch
-from peft import GraloraConfig, get_peft_model
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
+from peft import GraloraConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from datasets import load_dataset
+from trl import SFTTrainer, SFTConfig
 
 model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
 dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
 gralora_config = GraloraConfig()
-peft_model = get_peft_model(model, gralora_config)
-trainer = transformers.Trainer(
-    model=peft_model,
+
+trainer = SFTTrainer(
+    model=model,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
+    peft_config=gralora_config,
+    args=SFTConfig(
+        output_dir="./gralora-llama-7b",
+        max_length=2048,
+        dataset_text_field="text",
+        per_device_train_batch_size=2,
+    ),
 )
 trainer.train()
-peft_model.save_pretrained("gralora-llama-3-8b")
+trainer.model.save_pretrained("gralora-llama-7b")
 ```
 
 Run the finetuning script simply by running:
-```python
+```sh
 python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
 ```
 
@@ -51,12 +57,9 @@ model = AutoModelForCausalLM.from_pretrained(
 peft_model = PeftModel.from_pretrained(model, "gralora-llama-3-8b")
 ```
 
-## Additonal Notes
+## Additional Notes
 While `gralora_k` is set to 2 for default, you can increase this value to create more fine-grained adapters. `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher.
 
-
-
-
 ## Citation
 ```
 @misc{jung2025graloragranularlowrankadaptation,
```
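
The "Additional Notes" hunk recommends raising `gralora_k` at higher total ranks but shows no concrete configuration. As a minimal sketch, assuming `GraloraConfig` accepts `r`, `hybrid_r`, and `gralora_k` as keyword arguments (the diff itself only constructs `GraloraConfig()` with defaults, and these names come from the README's prose), the recommended setting could look like:

```python
from peft import GraloraConfig

# Sketch under assumptions: r, hybrid_r, and gralora_k are inferred from
# the README's prose; only GraloraConfig() with defaults appears in the diff.
gralora_config = GraloraConfig(
    r=48,          # base GraLoRA rank
    hybrid_r=16,   # hybrid rank, so total rank r + hybrid_r = 64
    gralora_k=4,   # finer-grained adapter split, suggested once total rank >= 64
)
```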

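One detail the commit leaves inconsistent: the new snippet saves the adapter to `gralora-llama-7b`, while the loading example shown as unchanged context in the second hunk still references `gralora-llama-3-8b`. A hedged sketch of reloading the adapter against the new save path, reusing the `PeftModel.from_pretrained` pattern from that hunk:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Assumes training ran with the SFTTrainer snippet above, so the adapter
# was written to ./gralora-llama-7b (the path passed to save_pretrained).
base_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
peft_model = PeftModel.from_pretrained(base_model, "gralora-llama-7b")
```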