examples/gralora_finetuning/README.md (31 changes: 17 additions & 14 deletions)
With respect to your standard PEFT training procedure with LoRA, simply swap your `LoraConfig` for a `GraloraConfig`:

```python
import torch
from peft import GraloraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Load the base model, tokenizer, and training data
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

# Swap the usual LoraConfig for a GraloraConfig; SFTTrainer applies it via peft_config
gralora_config = GraloraConfig()

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=gralora_config,
    args=SFTConfig(
        output_dir="./gralora-llama-7b",
        max_length=2048,
        dataset_text_field="text",
        per_device_train_batch_size=2,
    ),
)
trainer.train()
trainer.model.save_pretrained("gralora-llama-7b")
```
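
Because `peft_config` is passed, `SFTTrainer` wraps the base model with the GraLoRA adapter, which is why the snippet above saves `trainer.model`. If you want to verify how many parameters the adapter actually trains, the wrapped model exposes the usual PEFT helper (a minimal check, assuming the wrapped model is accessible as `trainer.model` as in the save call above):

```python
# trainer.model is the PEFT-wrapped model, so standard PEFT helpers are available
trainer.model.print_trainable_parameters()
```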

Run the finetuning script with:
```sh
python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
```

To use the trained adapter, load it on top of the base model with `PeftModel`:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # base model used by the finetuning script above
    device_map="auto",
)
peft_model = PeftModel.from_pretrained(model, "gralora-llama-3-8b")
```
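
The loaded `peft_model` behaves like the underlying `transformers` model with the GraLoRA weights applied, so generation works as usual. A minimal sketch (the prompt and generation settings are only illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

prompt = "### Human: What is GraLoRA?### Assistant:"  # illustrative prompt format
inputs = tokenizer(prompt, return_tensors="pt").to(peft_model.device)

outputs = peft_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```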

## Additional Notes
While `gralora_k` is set to 2 by default, you can increase this value to create more fine-grained adapters. `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher, as in the sketch below.
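
For instance, with a total rank of 64 you might distribute most of the rank across a 4 x 4 block grid and keep a small hybrid component. A sketch of such a configuration; only `r`, `hybrid_r`, and `gralora_k` are named above, while `target_modules` is assumed to follow the usual PEFT conventions, so check the `GraloraConfig` docstring for the exact fields:

```python
from peft import GraloraConfig

# Total rank r + hybrid_r = 64, so gralora_k=4 is the recommended granularity
gralora_config = GraloraConfig(
    r=48,          # rank distributed over the block grid (kept divisible by gralora_k)
    hybrid_r=16,   # rank of the additional hybrid component
    gralora_k=4,   # split each adapted weight matrix into a 4 x 4 grid of blocks
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```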

## Citation
```
@misc{jung2025graloragranularlowrankadaptation,