Commit 6dfa24e

feat: Add Gralora configuration and basic implementation

1 parent a18ba67 · commit 6dfa24e
7 files changed: +755 -0 lines changed

src/peft/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -64,6 +64,8 @@
     EvaConfig,
     FourierFTConfig,
     FourierFTModel,
+    GraloraConfig,
+    GraloraModel,
     HRAConfig,
     HRAModel,
     IA3Config,
@@ -163,6 +165,8 @@
     "EvaConfig",
     "FourierFTConfig",
     "FourierFTModel",
+    "GraloraConfig",
+    "GraloraModel",
    "HRAConfig",
    "HRAModel",
    "IA3Config",

src/peft/tuners/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,7 @@
 from .cpt import CPTConfig, CPTEmbedding
 from .delora import DeloraConfig, DeloraModel
 from .fourierft import FourierFTConfig, FourierFTModel
+from .gralora import GraloraConfig, GraloraModel
 from .hra import HRAConfig, HRAModel
 from .ia3 import IA3Config, IA3Model
 from .ln_tuning import LNTuningConfig, LNTuningModel
@@ -74,6 +75,8 @@
     "EvaConfig",
     "FourierFTConfig",
     "FourierFTModel",
+    "GraloraConfig",
+    "GraloraModel",
     "HRAConfig",
     "HRAModel",
     "IA3Config",
src/peft/tuners/gralora/__init__.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config import GraloraConfig
+from .layer import GraloraLayer
+from .model import GraloraModel
+
+
+__all__ = ["GraloraConfig", "GraloraLayer", "GraloraModel"]

src/peft/tuners/gralora/config.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+from peft.config import PeftConfig
+from peft.utils import PeftType
+
+
+@dataclass
+class GraloraConfig(PeftConfig):
+    r: int = field(default=8, metadata={"help": "gralora attention dimension"})
+    hybrid_r: int = field(
+        default=0,
+        metadata={"help": "hybrid_r is the rank allocated to the vanilla LoRA method when using Hybrid GraLoRA"},
+    )
+    target_modules: Optional[Union[list[str], str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "List of module names or regex expression of the module names to replace with gralora. "
+                "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
+                "Only linear layers are supported."
+            )
+        },
+    )
+    gralora_alpha: int = field(default=8, metadata={"help": "gralora alpha"})
+    gralora_dropout: float = field(default=0.0, metadata={"help": "gralora dropout"})
+    gralora_k: int = field(default=2, metadata={"help": "gralora k"})
+    fan_in_fan_out: bool = field(
+        default=False,
+        metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
+    )
+    bias: str = field(
+        default="none", metadata={"help": "Bias type for gralora. Can be 'none', 'all' or 'gralora_only'"}
+    )
+    modules_to_save: Optional[list[str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "List of modules apart from gralora layers to be set as trainable and saved in the final checkpoint."
+                " For example, in Sequence Classification or Token Classification tasks, the final layer"
+                " `classifier/score` is randomly initialized and as such needs to be trainable and saved."
+            )
+        },
+    )
+    layers_to_transform: Optional[Union[list[int], int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The layer indexes to transform. If this argument is specified, PEFT will transform only the layer"
+                " indexes that are specified inside this list. If a single integer is passed, PEFT will transform"
+                " only the layer at this index."
+            )
+        },
+    )
+    layers_pattern: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The layer pattern name, used only if `layers_to_transform` is different from None and if the layer"
+                " pattern is not in the common layers pattern."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        self.peft_type = PeftType.GRALORA
+        self.target_modules = (
+            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
+        )
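
For orientation, a hedged usage sketch of the new config with PEFT's standard entry point. The base checkpoint and target module names are illustrative assumptions, and it presumes GraloraModel is registered for PeftType.GRALORA in the commit's other files:

# Usage sketch, not part of this commit. Assumes PeftType.GRALORA is mapped
# to GraloraModel in PEFT's tuner registry elsewhere in these 7 files.
from transformers import AutoModelForCausalLM  # illustrative base model
from peft import GraloraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # assumed checkpoint
config = GraloraConfig(
    r=8,               # "gralora attention dimension"
    gralora_alpha=8,   # scaling factor, analogous to LoRA's alpha
    gralora_k=2,       # "gralora k" block parameter
    hybrid_r=0,        # rank given to a vanilla LoRA path in Hybrid GraLoRA
    gralora_dropout=0.0,
    target_modules=["q_proj", "v_proj"],  # illustrative module names
)
model = get_peft_model(base, config)
model.print_trainable_parameters()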
