Skip to content

Commit a18ba67

Browse files
authored
ENH Updates to OFT (huggingface#2805)
- better speed; - improved Cayley-Neumann parameterization; - numerically more stable merging. Note that these changes result in slightly different outputs from OFT, so it is recommended to retrain OFT checkpoints when upgrading to PEFT >= 0.18.0.
1 parent 546927d commit a18ba67

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/peft/tuners/oft/config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
from __future__ import annotations
1616

17+
import warnings
1718
from dataclasses import dataclass, field
1819
from typing import Literal, Optional, Union
1920

21+
import packaging.version
22+
2023
from peft.config import PeftConfig
2124
from peft.utils import PeftType
2225

@@ -193,4 +196,18 @@ def check_kwargs(cls, **kwargs):
193196
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
194197
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
195198
)
199+
if kwargs.get("use_cayley_neumann", False):
200+
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version
201+
# remove commit hash, if present
202+
peft_version = peft_version.partition("@")[0]
203+
parsed_version = packaging.version.Version(peft_version)
204+
min_version = packaging.version.Version("0.18.0")
205+
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
206+
if parsed_version < min_version:
207+
msg = (
208+
"The cayley-neumann parameterization has been slightly changed to be more numerically stable in "
209+
"PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, "
210+
"downgrade PEFT to version 0.17.0 to use the old parameterization."
211+
)
212+
warnings.warn(msg)
196213
return super().check_kwargs(**kwargs)

src/peft/tuners/oft/layer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def __init__(
9797
self.use_cayley_neumann = use_cayley_neumann
9898
self.num_cayley_neumann_terms = num_cayley_neumann_terms
9999
# Create indices for upper triangle (excluding diagonal)
100-
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
100+
rows, cols = torch.triu_indices(block_size, block_size, 1)
101+
self.register_buffer("rows", rows, persistent=False)
102+
self.register_buffer("cols", cols, persistent=False)
101103

102104
def _pytorch_skew_symmetric(self, vec, block_size):
103105
batch_size = vec.shape[0]
@@ -139,9 +141,11 @@ def _cayley_batch(
139141
R.add_(Q_squared, alpha=2.0)
140142

141143
Q_power = Q_squared
142-
for i in range(3, num_neumann_terms):
144+
for _ in range(3, num_neumann_terms - 1):
143145
Q_power = torch.bmm(Q_power, Q_skew)
144146
R.add_(Q_power, alpha=2.0)
147+
Q_power = torch.bmm(Q_power, Q_skew)
148+
R.add_(Q_power)
145149
else:
146150
id_mat = (
147151
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
@@ -621,9 +625,13 @@ def unmerge(self) -> None:
621625
if active_adapter in self.oft_R.keys():
622626
oft_mat = self.get_delta_weight(active_adapter)
623627

628+
previous_dtype = oft_mat.dtype
629+
if previous_dtype != torch.float32:
630+
oft_mat = oft_mat.to(torch.float32)
631+
624632
orig_weights = self.get_base_layer().weight.data
625633
orig_weights = torch.transpose(orig_weights, 0, 1)
626-
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
634+
orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
627635
orig_weights = torch.transpose(orig_weights, 0, 1)
628636

629637
base_layer.weight.data = orig_weights.to(orig_dtype)
@@ -855,13 +863,17 @@ def unmerge(self) -> None:
855863
if active_adapter in self.oft_R.keys():
856864
oft_mat = self.get_delta_weight(active_adapter)
857865

866+
previous_dtype = oft_mat.dtype
867+
if previous_dtype != torch.float32:
868+
oft_mat = oft_mat.to(torch.float32)
869+
858870
orig_weights = self.get_base_layer().weight.data.clone()
859871
orig_weights = orig_weights.view(
860872
self.out_features,
861873
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
862874
)
863875
orig_weights = torch.transpose(orig_weights, 0, 1)
864-
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
876+
orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
865877
orig_weights = torch.transpose(orig_weights, 0, 1)
866878
orig_weights = orig_weights.view(
867879
self.out_features,

0 commit comments

Comments
 (0)