
Commit 925ad72

ADD: documentation, examples, and test code for GraLoRA method
1 parent dec25f5 commit 925ad72


6 files changed, +340 -0 lines changed

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
# GraLoRA

[**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization.

![GraLoRA Overview](https://github.com/SqueezeBits/GraLoRA/raw/main/figure/gralora_overview.png)

Unlike standard LoRA, which applies a single low-rank adapter across the entire feature space, GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $k^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions. As a result, each adapter operates on a subspace that is $k$ times smaller in both dimensions than the original LoRA adapter.

This granular decomposition enables spatially localized and context-aware updates, effectively increasing representational capacity without additional parameters or computational cost. By isolating the influence of extreme activations within smaller subspaces, GraLoRA mitigates gradient distortion and preserves inter-channel balance during adaptation.

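Assuming each sub-block uses rank $r/k$, the parameter count matches a rank-$r$ LoRA on an $m \times n$ weight: $k^2 \cdot (\frac{m}{k} \cdot \frac{r}{k} + \frac{r}{k} \cdot \frac{n}{k}) = mr + rn$, consistent with the "no additional parameters" claim above. The snippet below is a minimal, illustrative sketch of this block structure; it is **not** the PEFT implementation, and all names, shapes, and the initialization are assumptions made for the example.

```python
import torch

# Illustrative GraLoRA-style block structure: the input and output dimensions
# are split into k groups, and each of the k*k blocks owns its own low-rank
# pair A[i, j] / B[i, j] with per-block rank r // k (assumed here so that the
# total parameter count equals that of a rank-r LoRA).
in_features, out_features, r, k = 64, 64, 8, 2
A = torch.randn(k, k, in_features // k, r // k) * 0.01
B = torch.zeros(k, k, r // k, out_features // k)  # zero-init so the initial update is zero


def gralora_delta(x: torch.Tensor) -> torch.Tensor:
    """Apply the adapter update to `x` of shape (batch, in_features)."""
    x_blocks = x.view(x.shape[0], k, in_features // k)
    out = x.new_zeros(x.shape[0], k, out_features // k)
    for i in range(k):  # input block index
        for j in range(k):  # output block index
            out[:, j] += x_blocks[:, i] @ A[i, j] @ B[i, j]
    return out.reshape(x.shape[0], out_features)
```

Each block only sees its own slice of the input and only writes to its own slice of the output, which is what confines the effect of outlier activations to a single sub-block.
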
---

The abstract from the paper is:

*Low-Rank Adaptation (LoRA) is a popular method for parameter-efficient fine-tuning (PEFT) of generative models, valued for its simplicity and effectiveness. Despite recent enhancements, LoRA still suffers from a fundamental limitation: overfitting when the bottleneck is widened. It performs best at ranks 32–64, yet its accuracy stagnates or declines at higher ranks, still falling short of full fine-tuning (FFT) performance. We identify the root cause as LoRA’s structural bottleneck, which introduces gradient entanglement to the unrelated input channels and distorts gradient propagation. To address this, we introduce a novel structure, Granular Low-Rank Adaptation (GraLoRA) that partitions weight matrices into sub-blocks, each with its own low-rank adapter. With negligible computational or storage cost, GraLoRA overcomes LoRA’s limitations, effectively increases the representational capacity, and more closely approximates FFT behavior. Experiments on code generation, commonsense reasoning, mathematical reasoning, general language understanding, and image generation benchmarks show that GraLoRA consistently outperforms LoRA and other baselines, achieving up to +8.5% absolute gain in Pass@1 on HumanEval+. These improvements hold across model sizes and rank settings, making GraLoRA a scalable and robust solution for PEFT.*
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# GraLoRA: Granular Low-Rank Adaptation

![GraLoRA Overview](https://github.com/SqueezeBits/GraLoRA/raw/main/figure/gralora_overview.png)

## Introduction
[**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization.

GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $k^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions.

## Quick start

Compared to a standard PEFT training run with LoRA, you only need to swap your `LoraConfig` for a `GraloraConfig`.

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

from peft import GraloraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

gralora_config = GraloraConfig()
peft_model = get_peft_model(model, gralora_config)

training_args = SFTConfig(output_dir="gralora-llama-7b", dataset_text_field="text", max_seq_length=2048)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("gralora-llama-7b")
```

Run the finetuning script with:

```bash
python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
```
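
The script also exposes the GraLoRA-specific options `--gralora_r`, `--gralora_alpha`, `--gralora_dropout`, `--gralora_target_modules`, `--gralora_k`, and `--hybrid_r`; see the argument parser at the bottom of the script for their defaults.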

## Use the model on 🤗
You can load and use the model like any other 🤗 Transformers model.
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "gralora-llama-7b")
```
53+
54+
## Additonal Notes
55+
While `gralora_k` is set to 2 for default, you can increase this value to create more fine-grained adapters. `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher.
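
For example, a higher-rank setup could look like the sketch below. The parameter names follow the example script in this folder; the specific values are illustrative, and the `hybrid_r` comment reflects a reading of the paper's hybrid variant.

```python
from peft import GraloraConfig

config = GraloraConfig(
    r=64,  # total GraLoRA rank
    gralora_alpha=128,  # scaling factor, analogous to lora_alpha
    gralora_k=4,  # 4x4 grid of sub-block adapters, recommended once the total rank reaches 64
    hybrid_r=0,  # optional rank reserved for a standard (non-granular) LoRA component
    gralora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
```
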
## Citation
```
@misc{jung2025graloragranularlowrankadaptation,
      title={GraLoRA: Granular Low-Rank Adaptation for Parameter-Efficient Fine-Tuning},
      author={Yeonjoon Jung and Daehyun Ahn and Hyungjun Kim and Taesu Kim and Eunhyeok Park},
      year={2025},
      eprint={2505.20355},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.20355},
}
```
Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# This script is based on examples/dora_finetuning/dora_finetuning.py
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import GraloraConfig, get_peft_model, prepare_model_for_kbit_training


def train_model(
    base_model: str,
    data_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    val_set_size: int,
    quantize: bool,
    eval_step: int,
    save_step: int,
    device: str,
    gralora_r: int,
    gralora_alpha: int,
    gralora_dropout: float,
    gralora_target_modules: str,
    gralora_k: int,
    hybrid_r: int,
    hub_model_id: str,
    push_to_hub: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hf_token = os.getenv("HF_TOKEN")

    # Setup device
    if device == "auto":
        device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    else:
        device = torch.device(device)
    print(f"Using device: {device}")

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)

    # Quantized GraLoRA: optionally load the base model in 4-bit (NF4)
    if quantize:
        if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) or torch.xpu.is_available():
            bnb_4bit_compute_dtype = torch.bfloat16
        else:
            bnb_4bit_compute_dtype = torch.float16
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            token=hf_token,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            ),
        )
        # setup for quantized training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token)

    # GraLoRA config for the PEFT model
    gralora_config = GraloraConfig(
        r=gralora_r,  # GraLoRA rank
        gralora_alpha=gralora_alpha,
        target_modules=(
            gralora_target_modules.split(",")
            if gralora_target_modules
            else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        ),
        gralora_dropout=gralora_dropout,
        gralora_k=gralora_k,
        hybrid_r=hybrid_r,
        bias="none",
    )

    # get the peft model with GraLoRA config
    model = get_peft_model(model, gralora_config)

    if not quantize:
        # 4-bit bitsandbytes models are already placed on the device and do not support `.to()`
        model.to(device)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(data_path)

    def tokenize_function(examples):
        inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = inputs["input_ids"].copy()  # setting labels for a language modeling task
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        gradient_accumulation_steps=16,
        fp16=True,
        learning_rate=learning_rate,
        hub_token=hf_token,
    )

    # Clear device cache to free memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.xpu.is_available():
        torch.xpu.empty_cache()

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )

    # Start model training
    trainer.train()

    # Save and push the trained model and tokenizer
    if push_to_hub:
        # Push the main model to the hub
        trainer.push_to_hub(commit_message="Fine-tuned model")

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune LLaMA with GraLoRA and PEFT")
    parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
    parser.add_argument(
        "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
    )
    parser.add_argument(
        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
    parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--device", type=str, default="auto", help="Device to use for training")
    parser.add_argument("--gralora_r", type=int, default=8, help="GraLoRA rank")
    parser.add_argument("--gralora_alpha", type=int, default=16, help="GraLoRA alpha")
    parser.add_argument("--gralora_dropout", type=float, default=0.05, help="GraLoRA dropout rate")
    parser.add_argument(
        "--gralora_target_modules", type=str, default=None, help="Comma-separated list of target modules for GraLoRA"
    )
    parser.add_argument("--gralora_k", type=int, default=2, help="Number of GraLoRA sub-blocks per dimension (k)")
    parser.add_argument("--hybrid_r", type=int, default=0, help="Hybrid rank")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default="path/to/repo",
        help="Repository name to push the model on the Hugging Face Hub",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
    args = parser.parse_args()
    train_model(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device=args.device,
        gralora_r=args.gralora_r,
        gralora_alpha=args.gralora_alpha,
        gralora_dropout=args.gralora_dropout,
        gralora_target_modules=args.gralora_target_modules,
        gralora_k=args.gralora_k,
        hybrid_r=args.hybrid_r,
        hub_model_id=args.hub_model_id,
        push_to_hub=args.push_to_hub,
    )

tests/test_encoder_decoder_models.py

Lines changed: 8 additions & 0 deletions
@@ -24,6 +24,7 @@
     C3AConfig,
     DeloraConfig,
     FourierFTConfig,
+    GraloraConfig,
     HRAConfig,
     IA3Config,
     LoraConfig,
@@ -100,6 +101,13 @@
             "task_type": "SEQ_2_SEQ_LM",
         },
     ),
+    (
+        GraloraConfig,
+        {
+            "target_modules": None,
+            "task_type": "SEQ_2_SEQ_LM",
+        },
+    ),
     (
         HRAConfig,
         {

tests/test_feature_extraction_models.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
     C3AConfig,
     DeloraConfig,
     FourierFTConfig,
+    GraloraConfig,
     HRAConfig,
     IA3Config,
     LoraConfig,
@@ -98,6 +99,13 @@
             "target_modules": None,
         },
     ),
+    (
+        GraloraConfig,
+        {
+            "task_type": "FEATURE_EXTRACTION",
+            "target_modules": None,
+        },
+    ),
     (
         HRAConfig,
         {

tests/test_seq_classifier.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
     C3AConfig,
     DeloraConfig,
     FourierFTConfig,
+    GraloraConfig,
     HRAConfig,
     IA3Config,
     LoraConfig,
@@ -99,6 +100,13 @@
             "target_modules": None,
         },
     ),
+    (
+        GraloraConfig,
+        {
+            "task_type": "SEQ_CLS",
+            "target_modules": None,
+        },
+    ),
     (
         HRAConfig,
         {
