Skip to content

Commit 351877f

Browse files
FIX CPU casting in GraLoRA get_delta_weight function
1 parent 430e896 commit 351877f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/peft/tuners/gralora/layer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
277277
l_indices = torch.arange(in_features, device=device)
278278
n_indices = l_indices // (in_features // gralora_k)
279279
i_indices = l_indices % (in_features // gralora_k)
280-
gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank, device=device, dtype=dtype)
280+
gralora_A_scattered = torch.zeros(
281+
in_features, gralora_k, gralora_rank, device=device, dtype=torch.float32 if cast_to_fp32 else dtype
282+
)
281283
gralora_A_scattered.scatter_(
282284
1,
283285
n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),

0 commit comments

Comments
 (0)