@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +13,15 @@
 # limitations under the License.
 
 import math
+import warnings
 from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners.tuners_utils import BaseTunerLayer
+from peft.utils.other import transpose
 
 
 class GraloraLayer(BaseTunerLayer):
@@ -62,7 +64,7 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
6264 """
6365 Move the adapter of the given name to the device of the base layer.
6466 """
65- from peft .tuners .vera . buffer_dict import BufferDict
67+ from peft .tuners ._buffer_dict import BufferDict
6668
6769 if device is None :
6870 # check weight and qweight (for GPTQ)
@@ -113,6 +115,7 @@ def update_layer(
         gralora_dropout,
         gralora_k: int = 2,
         hybrid_r: int = 0,
+        init_weights: bool = True,
     ):
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@@ -141,21 +144,34 @@ def update_layer(
         for _ in range(gralora_k):
             new_A = nn.Parameter(torch.zeros(gralora_r, subblock_in_features))
             new_B = nn.Parameter(torch.zeros(subblock_out_features, gralora_r))
-            nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+            if init_weights:
+                # Initialize to identity: A is random, B is zero
+                nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+                # new_B is already initialized to zeros
+            else:
+                # Initialize to random: both A and B are random (for testing)
+                nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
+                nn.init.kaiming_uniform_(new_B, a=math.sqrt(5))
             gralora_A.append(new_A)
             gralora_B.append(new_B)
         # stack A and B and transpose to get the final shape
-        gralora_A = torch.stack(tuple(gralora_A), dim=0)  # [N, rank, in_features//N]
-        gralora_A = gralora_A.transpose(1, 2).contiguous()  # [N, in_features//N, rank]
+        gralora_A = torch.stack(tuple(gralora_A), dim=0)  # [N, gralora_r, in_features//N]
+        gralora_A = gralora_A.transpose(1, 2).contiguous()  # [N, in_features//N, gralora_r]
 
-        gralora_B = torch.stack(tuple(gralora_B), dim=0)  # [N, out_features//N, rank]
-        gralora_B = gralora_B.transpose(1, 2).contiguous()  # [N, rank, out_features//N]
+        gralora_B = torch.stack(tuple(gralora_B), dim=0)  # [N, out_features//N, gralora_r]
+        gralora_B = gralora_B.transpose(1, 2).contiguous()  # [N, gralora_r, out_features//N]
 
         if hybrid_r > 0:
             general_gralora_A = nn.Linear(self.in_features, hybrid_r, bias=False)
             general_gralora_B = nn.Linear(hybrid_r, self.out_features, bias=False)
-            nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
-            nn.init.zeros_(general_gralora_B.weight)
+            if init_weights:
+                # Initialize to identity: A is random, B is zero
+                nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
+                nn.init.zeros_(general_gralora_B.weight)
+            else:
+                # Initialize to random: both A and B are random (for testing)
+                nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5))
+                nn.init.kaiming_uniform_(general_gralora_B.weight, a=math.sqrt(5))
         else:
             general_gralora_A = nn.Identity()
             general_gralora_B = nn.Identity()
@@ -185,6 +201,7 @@ def __init__(
         gralora_k: int = 2,
         hybrid_r: int = 0,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        init_weights: bool = True,
         **kwargs,
     ) -> None:
         # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called
@@ -193,16 +210,176 @@ def __init__(
         self.fan_in_fan_out = fan_in_fan_out
 
         self._active_adapter = adapter_name
-        self.update_layer(adapter_name, module_name, r, gralora_alpha, gralora_dropout, gralora_k, hybrid_r)
+        self.update_layer(
+            adapter_name, module_name, r, gralora_alpha, gralora_dropout, gralora_k, hybrid_r, init_weights
+        )
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
-        raise NotImplementedError("Merging is not supported for GraloraLayer yet.")
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`list[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged.
+                Defaults to `None`.
+        """
+        from peft.tuners.tuners_utils import check_adapters_to_merge
+
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        for active_adapter in adapter_names:
+            if active_adapter in self.gralora_A.keys():
+                base_layer = self.get_base_layer()
+                if safe_merge:
+                    # Note that safe_merge will be slower than the normal merge
+                    # because of the copy operation.
+                    orig_weights = base_layer.weight.data.clone()
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    orig_weights += delta_weight
+
+                    if not torch.isfinite(orig_weights).all():
+                        raise ValueError(
+                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+                        )
+
+                    base_layer.weight.data = orig_weights
+                else:
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += delta_weight
+
+                self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
-        raise NotImplementedError("Unmerging is not supported for GraloraLayer yet.")
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.gralora_A.keys():
+                delta_weight = self.get_delta_weight(active_adapter)
+                self.get_base_layer().weight.data -= delta_weight
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        raise NotImplementedError("Getting delta weight is not supported for GraloraLayer yet.")
+        """
+        Compute the delta weight for GraLoRA adapter.
+
+        GraLoRA applies block-wise low-rank adaptation with information exchange.
+        This method computes the equivalent weight matrix that would be added to
+        the base weight during merge.
+
+        Args:
+            adapter (str): The name of the adapter
+
+        Returns:
+            torch.Tensor: The delta weight matrix with shape [out_features, in_features]
+        """
+        gralora_A = self.gralora_A[adapter]  # [N, in_features//N, rank]
+        gralora_B = self.gralora_B[adapter]  # [N, rank, out_features//N]
+        gralora_A_general = self.gralora_A_general[adapter]
+        gralora_B_general = self.gralora_B_general[adapter]
+
+        device = gralora_A.device
+        dtype = gralora_A.dtype
+
+        gralora_k = self.gralora_k[adapter]
+        hybrid_r = self.hybrid_r[adapter]
+        r = self.r[adapter]
+
+        # Handle CPU fp16/bf16 casting
+        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
+
+        if cast_to_fp32:
+            gralora_A = gralora_A.float()
+            gralora_B = gralora_B.float()
+
+        # Get dimensions
+        in_features = self.in_features
+        out_features = self.out_features
+        subblock_in = in_features // gralora_k
+        subblock_out = out_features // gralora_k
+        gralora_rank = r - hybrid_r
+        subblock_gralora_rank = gralora_rank // gralora_k
+
+        # Simulate the forward pass computation to get equivalent weight matrix
+        # We need to compute: W_delta such that W_delta @ x = gralora_forward(x) - base_forward(x)
+
+        # Create an identity matrix for each input dimension and compute output
+        # This gives us the columns of the weight matrix
+        delta_weight = torch.zeros(out_features, in_features, device=device, dtype=gralora_A.dtype)
+
+        # Process in batches to avoid memory issues
+        batch_size = min(256, in_features)
+        for start_idx in range(0, in_features, batch_size):
+            end_idx = min(start_idx + batch_size, in_features)
+            batch_len = end_idx - start_idx
+
+            # Create identity input: [batch_len, in_features]
+            x = torch.zeros(batch_len, in_features, device=device, dtype=gralora_A.dtype)
+            for i in range(batch_len):
+                x[i, start_idx + i] = 1.0
+
+            # Apply GraLoRA transformation (following forward logic)
+            # x shape: [batch_len, in_features]
+            N = gralora_k
+
+            # Reshape x: [batch_len, N, in_features//N]
+            x_reshaped = x.view(batch_len, N, in_features // N)
+
+            # Apply gralora_A: [batch_len, N, in_features//N] @ [N, in_features//N, rank]
+            # Result: [batch_len, N, rank]
+            temp = torch.einsum("bni, nir -> bnr", x_reshaped, gralora_A)
+
+            # Reshape and permute for information exchange
+            # [batch_len, N, rank] -> [batch_len, N, N, subblock_rank]
+            temp = temp.view(batch_len, N, N, subblock_gralora_rank)
+            # Permute to exchange the two block axes (the shape stays [batch_len, N, N, subblock_rank])
+            temp = temp.permute(0, 2, 1, 3)
+            # Reshape: [batch_len, N, N * subblock_rank]
+            temp = temp.reshape(batch_len, N, N * subblock_gralora_rank)
+
+            # Apply gralora_B: [batch_len, N, N*subblock_rank] @ [N, rank, out_features//N]
+            # Note: rank here is actually gralora_rank = N * subblock_gralora_rank
+            # Result: [batch_len, N, out_features//N]
+            output = torch.einsum("bnr, nro -> bno", temp, gralora_B)
+
+            # Reshape to [batch_len, out_features]
+            output = output.reshape(batch_len, out_features)
+
+            # Store in delta_weight (transpose because weight is [out, in])
+            delta_weight[:, start_idx:end_idx] = output.T
+
+        # Add hybrid LoRA component if present
+        if hybrid_r > 0:
+            # gralora_A_general maps in_features -> hybrid_r, gralora_B_general maps hybrid_r -> out_features
+            weight_A_general = gralora_A_general.weight  # [hybrid_r, in_features]
+            weight_B_general = gralora_B_general.weight  # [out_features, hybrid_r]
+
+            if cast_to_fp32:
+                weight_A_general = weight_A_general.float()
+                weight_B_general = weight_B_general.float()
+
+            # Compute delta for hybrid part: [out_features, hybrid_r] @ [hybrid_r, in_features]
+            delta_weight += weight_B_general @ weight_A_general
+
+        # Apply scaling and transpose if needed
+        delta_weight = transpose(delta_weight, self.fan_in_fan_out) * self.scaling[adapter]
+
+        # Cast back if needed
+        if cast_to_fp32:
+            delta_weight = delta_weight.to(dtype=dtype)
+
+        return delta_weight
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         previous_dtype = x.dtype
@@ -216,6 +393,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         else:
             result = self.base_layer(x, *args, **kwargs)
             torch_result_dtype = result.dtype
+
+            # Handle 2D input: [batch, features] -> [batch, 1, features]
+            # This is common for MLPs and other non-sequence models
+            x_is_2d = x.ndim == 2
+            if x_is_2d:
+                x = x.unsqueeze(1)  # [B, F] -> [B, 1, F]
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.gralora_A.keys():
                     continue
@@ -253,11 +437,17 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     .reshape(B, L, N, N * subblock_gralora_rank),
                     gralora_B,
                 ).reshape(B, L, -1)
+
+                # Squeeze back to 2D if input was 2D
+                if x_is_2d:
+                    output = output.squeeze(1)  # [B, 1, F] -> [B, F]
+
                 result += scaling * output.to(torch_result_dtype)
                 if hybrid_r > 0:
-                    result += scaling * gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype)))).to(
-                        torch_result_dtype
-                    )
+                    hybrid_output = gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype))))
+                    if x_is_2d:
+                        hybrid_output = hybrid_output.squeeze(1)
+                    result += scaling * hybrid_output.to(torch_result_dtype)
 
         result = result.to(previous_dtype)
         return result
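
For reference, the block-wise delta computation introduced in `get_delta_weight` can be sanity-checked in isolation. The sketch below is not part of the change set: it re-implements the same view/einsum/permute sequence on small, randomly filled stacked `A`/`B` tensors (all names are local to the sketch; the scaling factor and the hybrid LoRA term are left out) and verifies that the materialized matrix reproduces the block-wise adapter output, which is the property `merge()` and `unmerge()` rely on when adding or subtracting the delta from `base_layer.weight`.

```python
import torch


def gralora_delta(A, B, k, subblock_rank):
    # A: [k, in_features//k, k*subblock_rank], B: [k, k*subblock_rank, out_features//k]
    in_features = A.shape[1] * k
    out_features = B.shape[2] * k
    x = torch.eye(in_features)                       # probe every input dimension at once
    x = x.view(in_features, k, in_features // k)     # split the inputs into k sub-blocks
    t = torch.einsum("bni, nir -> bnr", x, A)        # per-block down-projection
    t = t.view(in_features, k, k, subblock_rank)     # expose the k rank groups per block
    t = t.permute(0, 2, 1, 3)                        # information exchange across blocks
    t = t.reshape(in_features, k, k * subblock_rank)
    out = torch.einsum("bnr, nro -> bno", t, B)      # per-block up-projection
    return out.reshape(in_features, out_features).T  # [out_features, in_features]


k, subblock_rank, in_f, out_f = 2, 4, 16, 8
A = torch.randn(k, in_f // k, k * subblock_rank)
B = torch.randn(k, k * subblock_rank, out_f // k)
delta = gralora_delta(A, B, k, subblock_rank)

# The materialized delta must reproduce the adapter output for arbitrary inputs
x = torch.randn(5, in_f)
t = torch.einsum("bni, nir -> bnr", x.view(5, k, in_f // k), A)
t = t.view(5, k, k, subblock_rank).permute(0, 2, 1, 3).reshape(5, k, k * subblock_rank)
ref = torch.einsum("bnr, nro -> bno", t, B).reshape(5, out_f)
assert torch.allclose(x @ delta.T, ref, atol=1e-4)
```

The same equivalence is what `merge(safe_merge=True)` exercises on the real layer: the delta (plus the hybrid term, scaled by `self.scaling[adapter]`) is added to `base_layer.weight`, and `unmerge()` subtracts it again.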