diff --git a/examples/gralora_finetuning/README.md b/examples/gralora_finetuning/README.md
index a911ab86d5..a0485e41e3 100644
--- a/examples/gralora_finetuning/README.md
+++ b/examples/gralora_finetuning/README.md
@@ -13,28 +13,34 @@ With respect to your standard PEFT training procedure with LoRA, simply swap you
 ```python
 import torch
-from peft import GraloraConfig, get_peft_model
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
+from peft import GraloraConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from datasets import load_dataset
+from trl import SFTTrainer, SFTConfig
 
 model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
 dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
 gralora_config = GraloraConfig()
-peft_model = get_peft_model(model, gralora_config)
-trainer = transformers.Trainer(
-    model=peft_model,
+
+trainer = SFTTrainer(
+    model=model,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
+    peft_config=gralora_config,
+    args=SFTConfig(
+        output_dir="./gralora-llama-7b",
+        max_length=2048,
+        dataset_text_field="text",
+        per_device_train_batch_size=2,
+    ),
 )
 trainer.train()
-peft_model.save_pretrained("gralora-llama-3-8b")
+trainer.model.save_pretrained("gralora-llama-7b")
 ```
 
-Run the finetuning script simply by running:
-```python
+Run the finetuning script with:
+```sh
 python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
 ```
@@ -51,12 +57,17 @@ model = AutoModelForCausalLM.from_pretrained(
 peft_model = PeftModel.from_pretrained(model, "gralora-llama-3-8b")
 ```
 
-## Additonal Notes
-While `gralora_k` is set to 2 for default, you can increase this value to create more fine-grained adapters. `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher.
-
-
-
+## Additional Notes
+While `gralora_k` defaults to 2, you can increase this value to create more fine-grained adapters. A `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher.
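+
+For example, a configuration along these lines (the exact `r`/`hybrid_r` split shown is illustrative) brings the total rank to 64, where `gralora_k=4` is the recommended setting:
+```python
+from peft import GraloraConfig
+
+# Illustrative values: total rank is r + hybrid_r = 48 + 16 = 64, so gralora_k=4 applies.
+gralora_config = GraloraConfig(r=48, hybrid_r=16, gralora_k=4)
+```
 
 ## Citation
 ```
 @misc{jung2025graloragranularlowrankadaptation,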