
Commit bfa1ef7

HaohanTsao authored and yeonjoon-jung01 committed
ENH Support merge/unmerge functionality in GraLoRA; support init_weights parameter for flexible initialization
1 parent 6dfa24e commit bfa1ef7

File tree: 4 files changed, +233 / -46 lines

src/peft/tuners/gralora/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from peft.utils import register_peft_method
+
 from .config import GraloraConfig
 from .layer import GraloraLayer
 from .model import GraloraModel
 
 
 __all__ = ["GraloraConfig", "GraloraLayer", "GraloraModel"]
+
+register_peft_method(name="gralora", config_cls=GraloraConfig, model_cls=GraloraModel)

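With the method registered, GraLoRA plugs into the standard PEFT entry points. A minimal usage sketch, assuming the config exposes the usual r and target_modules fields; the model name and target module names are illustrative, not part of this commit:

from transformers import AutoModelForCausalLM
from peft import get_peft_model
from peft.tuners.gralora import GraloraConfig

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = GraloraConfig(r=16, target_modules=["q_proj", "v_proj"])
# get_peft_model resolves "gralora" to GraloraModel through the registry entry added above
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()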

src/peft/tuners/gralora/config.py

Lines changed: 11 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -55,6 +55,15 @@ class GraloraConfig(PeftConfig):
             )
         },
     )
+    init_weights: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "Whether to initialize the weights of the GraLoRA layers with their default initialization. "
+                "Don't change this setting, except if you know exactly what you're doing."
+            )
+        },
+    )
     layers_to_transform: Optional[Union[list[int], int]] = field(
         default=None,
         metadata={
@@ -76,6 +85,7 @@ class GraloraConfig(PeftConfig):
     )
 
     def __post_init__(self):
+        super().__post_init__()
         self.peft_type = PeftType.GRALORA
         self.target_modules = (
             set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

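The new init_weights flag keeps the default identity-style initialization (the adapter's contribution starts at zero) and only randomizes both factors when explicitly disabled, which is mainly useful in tests that need a non-zero delta from the start. A short sketch, assuming r and target_modules follow the existing config fields:

from peft.tuners.gralora import GraloraConfig

# Default: each block's B starts at zero, so the adapter initially leaves the base outputs unchanged.
train_config = GraloraConfig(r=16, target_modules=["q_proj", "v_proj"], init_weights=True)

# For testing: A and B are both random, so the adapter perturbs outputs immediately.
test_config = GraloraConfig(r=16, target_modules=["q_proj", "v_proj"], init_weights=False)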

src/peft/tuners/gralora/layer.py

Lines changed: 206 additions & 16 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +13,15 @@
 # limitations under the License.
 
 import math
+import warnings
 from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners.tuners_utils import BaseTunerLayer
+from peft.utils.other import transpose
 
 
 class GraloraLayer(BaseTunerLayer):
@@ -62,7 +64,7 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
         """
         Move the adapter of the given name to the device of the base layer.
         """
-        from peft.tuners.vera.buffer_dict import BufferDict
+        from peft.tuners._buffer_dict import BufferDict
 
         if device is None:
             # check weight and qweight (for GPTQ)
@@ -113,6 +115,7 @@ def update_layer(
         gralora_dropout,
         gralora_k: int = 2,
         hybrid_r: int = 0,
+        init_weights: bool = True,
     ):
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@@ -141,21 +144,34 @@
         for _ in range(gralora_k):
             new_A = nn.Parameter(torch.zeros(gralora_r, subblock_in_features))
             new_B = nn.Parameter(torch.zeros(subblock_out_features, gralora_r))
-            nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+            if init_weights:
+                # Initialize to identity: A is random, B is zero
+                nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+                # new_B is already initialized to zeros
+            else:
+                # Initialize to random: both A and B are random (for testing)
+                nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+                nn.init.kaiming_uniform_(new_B, a=math.sqrt(5))
             gralora_A.append(new_A)
             gralora_B.append(new_B)
         # stack A and B and transpose to get the final shape
-        gralora_A = torch.stack(tuple(gralora_A), dim=0)  # [N, rank, in_features//N]
-        gralora_A = gralora_A.transpose(1, 2).contiguous()  # [N, in_features//N, rank]
+        gralora_A = torch.stack(tuple(gralora_A), dim=0)  # [N, gralora_r, in_features//N]
+        gralora_A = gralora_A.transpose(1, 2).contiguous()  # [N, in_features//N, gralora_r]
 
-        gralora_B = torch.stack(tuple(gralora_B), dim=0)  # [N, out_features//N, rank]
-        gralora_B = gralora_B.transpose(1, 2).contiguous()  # [N, rank, out_features//N]
+        gralora_B = torch.stack(tuple(gralora_B), dim=0)  # [N, out_features//N, gralora_r]
+        gralora_B = gralora_B.transpose(1, 2).contiguous()  # [N, gralora_r, out_features//N]
 
         if hybrid_r > 0:
             general_gralora_A = nn.Linear(self.in_features, hybrid_r, bias=False)
             general_gralora_B = nn.Linear(hybrid_r, self.out_features, bias=False)
-            nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
-            nn.init.zeros_(general_gralora_B.weight)
+            if init_weights:
+                # Initialize to identity: A is random, B is zero
+                nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
+                nn.init.zeros_(general_gralora_B.weight)
+            else:
+                # Initialize to random: both A and B are random (for testing)
+                nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
+                nn.init.kaiming_uniform_(general_gralora_B.weight, a=math.sqrt(5))
         else:
             general_gralora_A = nn.Identity()
             general_gralora_B = nn.Identity()

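The practical effect of the default branch is that every block's B starts at zero, so the adapter's contribution (and therefore get_delta_weight) is exactly zero right after initialization. A simplified, torch-only sketch of that property; it mirrors the stacking above but omits the cross-block permutation used in forward, which does not change the zero result:

import math
import torch
import torch.nn as nn

gralora_k, gralora_r, sub_in, sub_out = 2, 4, 8, 8
A_blocks, B_blocks = [], []
for _ in range(gralora_k):
    new_A = nn.Parameter(torch.zeros(gralora_r, sub_in))
    new_B = nn.Parameter(torch.zeros(sub_out, gralora_r))
    nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))  # A is random
    # new_B stays zero, as in the init_weights=True branch
    A_blocks.append(new_A)
    B_blocks.append(new_B)

gralora_A = torch.stack(A_blocks).transpose(1, 2)  # [N, sub_in, gralora_r]
gralora_B = torch.stack(B_blocks).transpose(1, 2)  # [N, gralora_r, sub_out]

x = torch.randn(3, gralora_k, sub_in)  # toy batch, already split into N blocks
hidden = torch.einsum("bni, nir -> bnr", x, gralora_A)
out = torch.einsum("bnr, nro -> bno", hidden, gralora_B)
assert torch.allclose(out, torch.zeros_like(out))  # zero delta at init
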
@@ -185,6 +201,7 @@ def __init__(
         gralora_k: int = 2,
         hybrid_r: int = 0,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        init_weights: bool = True,
         **kwargs,
     ) -> None:
         # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called
@@ -193,16 +210,176 @@
         self.fan_in_fan_out = fan_in_fan_out
 
         self._active_adapter = adapter_name
-        self.update_layer(adapter_name, module_name, r, gralora_alpha, gralora_dropout, gralora_k, hybrid_r)
+        self.update_layer(
+            adapter_name, module_name, r, gralora_alpha, gralora_dropout, gralora_k, hybrid_r, init_weights
+        )
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
-        raise NotImplementedError("Merging is not supported for GraloraLayer yet.")
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`list[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged.
+                Defaults to `None`.
+        """
+        from peft.tuners.tuners_utils import check_adapters_to_merge
+
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        for active_adapter in adapter_names:
+            if active_adapter in self.gralora_A.keys():
+                base_layer = self.get_base_layer()
+                if safe_merge:
+                    # Note that safe_merge will be slower than the normal merge
+                    # because of the copy operation.
+                    orig_weights = base_layer.weight.data.clone()
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    orig_weights += delta_weight
+
+                    if not torch.isfinite(orig_weights).all():
+                        raise ValueError(
+                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+                        )
+
+                    base_layer.weight.data = orig_weights
+                else:
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += delta_weight
+
+                self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
-        raise NotImplementedError("Unmerging is not supported for GraloraLayer yet.")
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.gralora_A.keys():
+                delta_weight = self.get_delta_weight(active_adapter)
+                self.get_base_layer().weight.data -= delta_weight
 
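At the model level these hooks are driven by PeftModel.merge_adapter() and unmerge_adapter(), which walk the wrapped modules and call merge()/unmerge() on every adapted layer. A hedged round-trip sketch (model name, target modules, and config fields are illustrative) checking that merging and then unmerging leaves the outputs unchanged:

import torch
from transformers import AutoModelForCausalLM
from peft import get_peft_model
from peft.tuners.gralora import GraloraConfig

model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("facebook/opt-125m"),
    GraloraConfig(r=16, target_modules=["q_proj", "v_proj"], init_weights=False),
)
input_ids = torch.tensor([[1, 2, 3, 4]])

with torch.no_grad():
    before = model(input_ids).logits
    model.merge_adapter()    # folds get_delta_weight() into each adapted base weight
    merged = model(input_ids).logits
    model.unmerge_adapter()  # subtracts the same delta again

assert torch.allclose(before, merged, atol=1e-3)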

     def get_delta_weight(self, adapter) -> torch.Tensor:
-        raise NotImplementedError("Getting delta weight is not supported for GraloraLayer yet.")
+        """
+        Compute the delta weight for GraLoRA adapter.
+
+        GraLoRA applies block-wise low-rank adaptation with information exchange.
+        This method computes the equivalent weight matrix that would be added to
+        the base weight during merge.
+
+        Args:
+            adapter (str): The name of the adapter
+
+        Returns:
+            torch.Tensor: The delta weight matrix with shape [out_features, in_features]
+        """
+        gralora_A = self.gralora_A[adapter]  # [N, in_features//N, rank]
+        gralora_B = self.gralora_B[adapter]  # [N, rank, out_features//N]
+        gralora_A_general = self.gralora_A_general[adapter]
+        gralora_B_general = self.gralora_B_general[adapter]
+
+        device = gralora_A.device
+        dtype = gralora_A.dtype
+
+        gralora_k = self.gralora_k[adapter]
+        hybrid_r = self.hybrid_r[adapter]
+        r = self.r[adapter]
+
+        # Handle CPU fp16/bf16 casting
+        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
+
+        if cast_to_fp32:
+            gralora_A = gralora_A.float()
+            gralora_B = gralora_B.float()
+
+        # Get dimensions
+        in_features = self.in_features
+        out_features = self.out_features
+        subblock_in = in_features // gralora_k
+        subblock_out = out_features // gralora_k
+        gralora_rank = r - hybrid_r
+        subblock_gralora_rank = gralora_rank // gralora_k
+
+        # Simulate the forward pass computation to get equivalent weight matrix
+        # We need to compute: W_delta such that W_delta @ x = gralora_forward(x) - base_forward(x)
+
+        # Create an identity matrix for each input dimension and compute output
+        # This gives us the columns of the weight matrix
+        delta_weight = torch.zeros(out_features, in_features, device=device, dtype=gralora_A.dtype)
+
+        # Process in batches to avoid memory issues
+        batch_size = min(256, in_features)
+        for start_idx in range(0, in_features, batch_size):
+            end_idx = min(start_idx + batch_size, in_features)
+            batch_len = end_idx - start_idx
+
+            # Create identity input: [batch_len, in_features]
+            x = torch.zeros(batch_len, in_features, device=device, dtype=gralora_A.dtype)
+            for i in range(batch_len):
+                x[i, start_idx + i] = 1.0
+
+            # Apply GraLoRA transformation (following forward logic)
+            # x shape: [batch_len, in_features]
+            N = gralora_k
+
+            # Reshape x: [batch_len, N, in_features//N]
+            x_reshaped = x.view(batch_len, N, in_features // N)
+
+            # Apply gralora_A: [batch_len, N, in_features//N] @ [N, in_features//N, rank]
+            # Result: [batch_len, N, rank]
+            temp = torch.einsum("bni, nir -> bnr", x_reshaped, gralora_A)
+
+            # Reshape and permute for information exchange
+            # [batch_len, N, rank] -> [batch_len, N, N, subblock_rank]
+            temp = temp.view(batch_len, N, N, subblock_gralora_rank)
+            # Permute: [batch_len, N, N, subblock_rank] -> [batch_len, N, N, subblock_rank]
+            temp = temp.permute(0, 2, 1, 3)
+            # Reshape: [batch_len, N, N * subblock_rank]
+            temp = temp.reshape(batch_len, N, N * subblock_gralora_rank)
+
+            # Apply gralora_B: [batch_len, N, N*subblock_rank] @ [N, rank, out_features//N]
+            # Note: rank here is actually gralora_rank = N * subblock_gralora_rank
+            # Result: [batch_len, N, out_features//N]
+            output = torch.einsum("bnr, nro -> bno", temp, gralora_B)
+
+            # Reshape to [batch_len, out_features]
+            output = output.reshape(batch_len, out_features)
+
+            # Store in delta_weight (transpose because weight is [out, in])
+            delta_weight[:, start_idx:end_idx] = output.T
+
+        # Add hybrid LoRA component if present
+        if hybrid_r > 0:
+            # general_A: [in_features, hybrid_r], general_B: [hybrid_r, out_features]
+            weight_A_general = gralora_A_general.weight  # [hybrid_r, in_features]
+            weight_B_general = gralora_B_general.weight  # [out_features, hybrid_r]
+
+            if cast_to_fp32:
+                weight_A_general = weight_A_general.float()
+                weight_B_general = weight_B_general.float()
+
+            # Compute delta for hybrid part: [out_features, hybrid_r] @ [hybrid_r, in_features]
+            delta_weight += weight_B_general @ weight_A_general
+
+        # Apply scaling and transpose if needed
+        delta_weight = transpose(delta_weight, self.fan_in_fan_out) * self.scaling[adapter]
+
+        # Cast back if needed
+        if cast_to_fp32:
+            delta_weight = delta_weight.to(dtype=dtype)
+
+        return delta_weight
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         previous_dtype = x.dtype

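get_delta_weight recovers the merged update by probing the adapter with one-hot inputs: the adapter's output for the j-th basis vector is exactly the j-th column of the equivalent weight matrix. A tiny self-contained illustration of that identity-probe trick with an arbitrary linear map (shapes are made up):

import torch

out_features, in_features = 8, 6
W = torch.randn(out_features, in_features)

def adapter_forward(x):
    # stand-in for the block-wise GraLoRA computation; any linear map works here
    return x @ W.T

eye = torch.eye(in_features)        # one one-hot probe per input dimension
recovered = adapter_forward(eye).T  # columns of the equivalent weight matrix
assert torch.allclose(recovered, W)
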
@@ -216,6 +393,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         else:
             result = self.base_layer(x, *args, **kwargs)
             torch_result_dtype = result.dtype
+
+            # Handle 2D input: [batch, features] -> [batch, 1, features]
+            # This is common for MLPs and other non-sequence models
+            x_is_2d = x.ndim == 2
+            if x_is_2d:
+                x = x.unsqueeze(1)  # [B, F] -> [B, 1, F]
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.gralora_A.keys():
                     continue
@@ -253,11 +437,17 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     .reshape(B, L, N, N * subblock_gralora_rank),
                     gralora_B,
                 ).reshape(B, L, -1)
+
+                # Squeeze back to 2D if input was 2D
+                if x_is_2d:
+                    output = output.squeeze(1)  # [B, 1, F] -> [B, F]
+
                 result += scaling * output.to(torch_result_dtype)
                 if hybrid_r > 0:
-                    result += scaling * gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype)))).to(
-                        torch_result_dtype
-                    )
+                    hybrid_output = gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype))))
+                    if x_is_2d:
+                        hybrid_output = hybrid_output.squeeze(1)
+                    result += scaling * hybrid_output.to(torch_result_dtype)
 
         result = result.to(previous_dtype)
         return result

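With the unsqueeze/squeeze handling above, GraLoRA-wrapped layers accept plain 2D activations as well as the usual 3D sequence tensors. A hedged sketch on a toy MLP, assuming GraloraConfig exposes r and gralora_k and that the Sequential child names ("0", "2") are valid target_modules:

import torch
import torch.nn as nn
from peft import get_peft_model
from peft.tuners.gralora import GraloraConfig

mlp = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
config = GraloraConfig(r=8, gralora_k=2, target_modules=["0", "2"])
peft_mlp = get_peft_model(mlp, config)

x2d = torch.randn(5, 16)     # [batch, features] -> handled by the new unsqueeze/squeeze path
x3d = torch.randn(5, 3, 16)  # [batch, seq, features] -> unchanged behavior
print(peft_mlp(x2d).shape)   # torch.Size([5, 4])
print(peft_mlp(x3d).shape)   # torch.Size([5, 3, 4])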

0 commit comments
