
Commit 3f69d8f

REFACTOR: integrate GraLoRA tests into existing test files
1 parent 925ad72 commit 3f69d8f

7 files changed: +164 additions, -1086 deletions

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@
     title: VeRA
   - local: package_reference/fourierft
     title: FourierFT
+  - local: package_reference/gralora
+    title: GraLoRA
   - local: package_reference/vblora
     title: VB-LoRA
   - local: package_reference/hra

src/peft/tuners/gralora/config.py

Lines changed: 71 additions & 11 deletions
@@ -21,6 +21,57 @@
 
 @dataclass
 class GraloraConfig(PeftConfig):
+    """
+    This is the configuration class to store the configuration of a [`GraloraModel`].
+
+    Args:
+        r (`int`):
+            GraLoRA attention dimension determines the rank of the GraLoRA adapter.
+            The total parameter count of the GraLoRA adapter is the same as LoRA with the same rank r, while the expressivity is multiplied by gralora_k.
+        hybrid_r (`int`):
+            Hybrid GraLoRA rank determines the rank allocated to the vanilla LoRA method when using the Hybrid GraLoRA method.
+            Hybrid GraLoRA, a combination of GraLoRA and vanilla LoRA, becomes available when hybrid_r > 0.
+            The parameter count of the GraLoRA adapter is determined by r + hybrid_r.
+        target_modules (`Union[List[str], str]`):
+            List of module names or regex expression of the module names to replace with GraLoRA.
+            For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'.
+            This can also be a wildcard 'all-linear' which matches all linear/Conv1D
+            (if the model is a PreTrainedModel, the output layer excluded).
+            If not specified, modules will be chosen according to the model architecture. If the architecture is
+            not known, an error will be raised -- in this case, you should specify the target modules manually.
+            To avoid targeting any modules (because you want to apply `target_parameters`), set
+            `target_modules=[]`.
+        gralora_alpha (`int`):
+            GraLoRA alpha is the scaling factor for the GraLoRA adapter.
+            Scale becomes gralora_alpha / (r + hybrid_r).
+        gralora_dropout (`float`):
+            GraLoRA dropout is the dropout probability for the GraLoRA adapter.
+            It is used to prevent overfitting and improve the generalization of the GraLoRA adapter.
+        gralora_k (`int`):
+            GraLoRA k determines the number of subblocks in the GraLoRA adapter.
+            The rank r must be divisible by gralora_k for the GraLoRA adapter to be valid.
+            The total parameter count is preserved regardless of gralora_k.
+            The entire rank of the GraLoRA adapter is increased by a factor of gralora_k, while the rank of each subblock is reduced by a factor of gralora_k.
+            gralora_k=2 is recommended for rank 32 or lower, and gralora_k=4 is recommended for rank 64 or higher.
+        fan_in_fan_out (`bool`):
+            Set this to True if the layer to replace stores weight like (fan_in, fan_out).
+            For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
+        bias (`str`):
+            Bias type for GraLoRA. Can be 'none', 'all' or 'gralora_only'.
+            If 'all' or 'gralora_only', the corresponding biases will be updated during training.
+            Be aware that this means that, even when disabling the adapters, the model will not produce the same output as the base model would have without adaptation.
+        init_weights (`bool`):
+            Whether to initialize the weights of the GraLoRA layers with their default initialization.
+            Don't change this setting, except if you know exactly what you're doing.
+        layers_to_transform (`Union[List[int], int]`):
+            The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes that are specified inside this list.
+            If a single integer is passed, PEFT will transform only the layer at this index.
+            This only works when target_modules is a list of str.
+        layers_pattern (`Optional[Union[List[str], str]]`):
+            The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern.
+            This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
+    """
+
     r: int = field(
         default=32,
         metadata={
@@ -44,18 +95,23 @@ class GraloraConfig(PeftConfig):
         default=None,
         metadata={
             "help": (
-                "List of module names or regex expression of the module names to replace with gralora. "
+                "List of module names or regex expression of the module names to replace with GraLoRA. "
                 "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
-                "Only linear layers are supported."
+                "This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
+                "(if the model is a PreTrainedModel, the output layer excluded). "
+                "If not specified, modules will be chosen according to the model architecture. If the architecture is "
+                "not known, an error will be raised -- in this case, you should specify the target modules manually. "
+                "To avoid targeting any modules (because you want to apply `target_parameters`), set "
+                "`target_modules=[]`."
             )
         },
     )
     gralora_alpha: int = field(
         default=64,
         metadata={
             "help": (
-                "gralora alpha is the scaling factor for the GraLoRA adapter."
-                "Scale becomes gralora_alpha / (r + hybrid_r)."
+                "gralora alpha is the scaling factor for the GraLoRA adapter. "
+                "Scale becomes gralora_alpha / (r + hybrid_r)."
             )
         },
     )
@@ -64,8 +120,11 @@ class GraloraConfig(PeftConfig):
         default=2,
         metadata={
             "help": (
-                "gralora_k determines the number of subblocks in the GraLoRA adapter."
-                "The total parameter count is preserved regardles of gralora_k, while the expressivitiy is multiplied by gralora_k."
+                "gralora_k determines the number of subblocks in the GraLoRA adapter. "
+                "The rank r must be divisible by gralora_k for the GraLoRA adapter to be valid. "
+                "The total parameter count is preserved regardless of gralora_k. "
+                "The entire rank of the GraLoRA adapter is increased by a factor of gralora_k, while the rank of each subblock is reduced by a factor of gralora_k. "
+                "gralora_k=2 is recommended for rank 32 or lower, and gralora_k=4 is recommended for rank 64 or higher."
             )
         },
     )
@@ -99,18 +158,19 @@ class GraloraConfig(PeftConfig):
         default=None,
         metadata={
             "help": (
-                "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers"
-                " indexes that are specified inside this list. If a single integer is passed, PEFT will transform only"
-                " the layer at this index."
+                "The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes that are specified inside this list. "
+                "If a single integer is passed, PEFT will transform only the layer at this index. "
+                "This only works when target_modules is a list of str."
             )
         },
     )
     layers_pattern: Optional[str] = field(
         default=None,
         metadata={
             "help": (
-                "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer"
-                " pattern is not in the common layers pattern."
+                "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern. "
+                "This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the "
+                "model, which is often called `'layers'` or `'h'`."
            )
        },
    )
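
The docstring and help texts added above pin down how the GraLoRA hyperparameters interact: r must be divisible by gralora_k, hybrid_r > 0 switches on the hybrid variant, and the adapter scale is gralora_alpha / (r + hybrid_r). Below is a minimal usage sketch, assuming a PEFT install that includes this GraLoRA branch; ToyModel and its dimensions are hypothetical stand-ins chosen so the divisibility constraints hold, and the import path simply mirrors the config module shown in this diff.

import torch.nn as nn

from peft import get_peft_model
from peft.tuners.gralora.config import GraloraConfig


class ToyModel(nn.Module):
    # hypothetical module; 64 in/out features are divisible by gralora_k below
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)


config = GraloraConfig(
    target_modules=["proj"],
    r=32,              # must be divisible by gralora_k
    gralora_k=2,       # recommended for rank 32 or lower
    gralora_alpha=64,
    hybrid_r=0,        # > 0 would enable the Hybrid GraLoRA variant
)
model = get_peft_model(ToyModel(), config)
model.print_trainable_parameters()
print("scale =", config.gralora_alpha / (config.r + config.hybrid_r))  # 64 / (32 + 0) = 2.0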

src/peft/tuners/gralora/layer.py

Lines changed: 0 additions & 6 deletions
@@ -271,12 +271,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         in_features = self.in_features
         out_features = self.out_features
         gralora_rank = r
-        if in_features % gralora_k != 0:
-            raise ValueError(f"in_features should be divisible by gralora_k, but got {in_features} and {gralora_k}")
-        elif out_features % gralora_k != 0:
-            raise ValueError(f"out_features should be divisible by gralora_k, but got {out_features} and {gralora_k}")
-        elif gralora_rank % gralora_k != 0:
-            raise ValueError(f"rank should be divisible by gralora_k, but got {gralora_rank} and {gralora_k}")
         subblock_gralora_rank = gralora_rank // gralora_k
 
         # scatter gralora_A to get the scattered weight matrix
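
The block removed here made the shape constraints explicit: in_features, out_features, and the rank must all be divisible by gralora_k, and each subblock then works with rank // gralora_k. The following standalone sketch repeats that arithmetic for illustration only; it is not the library code, and the helper name is made up.

def gralora_subblock_shape(in_features: int, out_features: int, rank: int, gralora_k: int):
    # Same divisibility checks that the removed block performed.
    for name, value in (("in_features", in_features), ("out_features", out_features), ("rank", rank)):
        if value % gralora_k != 0:
            raise ValueError(f"{name} should be divisible by gralora_k, but got {value} and {gralora_k}")
    # Each of the gralora_k input/output blocks is adapted with rank // gralora_k.
    return in_features // gralora_k, out_features // gralora_k, rank // gralora_k


print(gralora_subblock_shape(64, 64, 32, gralora_k=2))  # (32, 32, 16)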

src/peft/tuners/gralora/model.py

Lines changed: 1 addition & 9 deletions
@@ -15,23 +15,15 @@
 from __future__ import annotations
 
 import warnings
-from dataclasses import asdict
-from enum import Enum
-from typing import Optional
 
 import torch
-import torch.nn as nn
-from tqdm import tqdm
 from transformers.pytorch_utils import Conv1D
 
-from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
-    ModulesToSaveWrapper,
-    _get_submodules,
 )
 
-from .config import GraloraConfig
 from .layer import GraloraLayer, Linear
 
 

tests/test_custom_models.py

Lines changed: 39 additions & 9 deletions
@@ -680,6 +680,18 @@
         GraloraConfig,
         {"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
     ),
+    (
+        "Vanilla MLP 6 GraLoRA",
+        "MLP",
+        GraloraConfig,
+        {"target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"]},
+    ),
+    (
+        "Vanilla MLP 7 Hybrid GraLoRA",
+        "MLP",
+        GraloraConfig,
+        {"target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"], "hybrid_r": 4},
+    ),
     (
         "Embedding + transformers Conv1D 1 GraLoRA",
         "EmbConv1D",
@@ -3124,12 +3136,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self):
         cancelled_B = module.lora_B["cancelled"].weight.data
 
         # The weights should be approximately zero (they cancel out)
-        assert torch.allclose(
-            cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5
-        ), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
-        assert torch.allclose(
-            cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5
-        ), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
+        assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), (
+            f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
+        )
+        assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), (
+            f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
+        )
 
     def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
         # Test negative weights with different scaling factors (lora_alpha)
@@ -3440,6 +3452,24 @@ def test_dora_save_and_load_remapping(self):
         for k in state_dict:
             assert torch.allclose(state_dict[k], state_dict_loaded[k])
 
+    def test_gralora_and_hybrid_gralora_parameter_count(self):
+        # Here we test that the parameter count of GraLoRA is preserved
+        # when rank r + hybrid_r is the same, regardless of the value of gralora_k.
+        model1 = MLP()
+        config1 = GraloraConfig(target_modules=["lin0"], r=12, gralora_k=2, hybrid_r=0)
+        model1 = get_peft_model(model1, config1)
+        model2 = MLP()
+        config2 = GraloraConfig(target_modules=["lin0"], r=10, gralora_k=2, hybrid_r=2)
+        model2 = get_peft_model(model2, config2)
+        model3 = MLP()
+        config3 = GraloraConfig(target_modules=["lin0"], r=10, gralora_k=5, hybrid_r=2)
+        model3 = get_peft_model(model3, config3)
+        trainable_params1, all_params1 = model1.get_nb_trainable_parameters()
+        trainable_params2, all_params2 = model2.get_nb_trainable_parameters()
+        trainable_params3, all_params3 = model3.get_nb_trainable_parameters()
+        assert trainable_params1 == trainable_params2 == trainable_params3
+        assert all_params1 == all_params2 == all_params3
+
     @pytest.mark.parametrize("with_forward_call", [False, True])
     def test_mha_gradients_set_correctly(self, with_forward_call):
         # check for this bug: https://github.com/huggingface/peft/issues/761#issuecomment-1893804738
@@ -3535,9 +3565,9 @@ def test_multirank_2(self):
             if isinstance(module, BaseTunerLayer):
                 rank_expected = rank_pattern.get(key, r)
                 rank_current = module.lora_A[adapter].weight.shape[0]
-                assert (
-                    rank_current == rank_expected
-                ), f"Rank {rank_current} is not equal to expected {rank_expected}"
+                assert rank_current == rank_expected, (
+                    f"Rank {rank_current} is not equal to expected {rank_expected}"
+                )
 
 
 class TestLayerRepr:
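
The new test asserts that GraLoRA's trainable parameter count depends only on r + hybrid_r, not on gralora_k. The sketch below replays that check outside the test suite, assuming a PEFT install with this branch; TinyMLP is a hypothetical stand-in for the test file's MLP fixture (its exact layer sizes are not shown in this diff), while the three config settings are copied verbatim from the test.

import torch.nn as nn

from peft import get_peft_model
from peft.tuners.gralora.config import GraloraConfig


class TinyMLP(nn.Module):
    # hypothetical stand-in; 20 features are divisible by both gralora_k values used below
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(20, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))


counts = []
for r, k, hybrid_r in [(12, 2, 0), (10, 2, 2), (10, 5, 2)]:  # r + hybrid_r == 12 in every case
    cfg = GraloraConfig(target_modules=["lin0"], r=r, gralora_k=k, hybrid_r=hybrid_r)
    peft_model = get_peft_model(TinyMLP(), cfg)
    counts.append(peft_model.get_nb_trainable_parameters())  # (trainable, total)

assert counts[0] == counts[1] == counts[2], counts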
