@@ -97,7 +97,9 @@ def __init__(
         self.use_cayley_neumann = use_cayley_neumann
         self.num_cayley_neumann_terms = num_cayley_neumann_terms
         # Create indices for upper triangle (excluding diagonal)
-        self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
+        rows, cols = torch.triu_indices(block_size, block_size, 1)
+        self.register_buffer("rows", rows, persistent=False)
+        self.register_buffer("cols", cols, persistent=False)

     def _pytorch_skew_symmetric(self, vec, block_size):
         batch_size = vec.shape[0]
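This hunk replaces plain tensor attributes with non-persistent buffers: tensors stored as plain attributes are not moved by `Module.to()`/`.cuda()`, so the triangular indices could be left on CPU while the parameters sit on GPU, and `persistent=False` keeps them out of the `state_dict` so checkpoints are unaffected. A minimal sketch of the behavior (the `IndexHolder` module is illustrative, not part of the patch):

```python
import torch
import torch.nn as nn

class IndexHolder(nn.Module):
    """Illustrative module, not part of the patch."""

    def __init__(self, block_size: int):
        super().__init__()
        rows, cols = torch.triu_indices(block_size, block_size, 1)
        # Non-persistent buffers move with the module on .to()/.cuda(),
        # but stay out of the state_dict, so checkpoints are unaffected.
        self.register_buffer("rows", rows, persistent=False)
        self.register_buffer("cols", cols, persistent=False)

m = IndexHolder(4)
print("rows" in m.state_dict())  # False: not serialized
# m.to("cuda") would carry rows/cols along with the parameters,
# which a plain `self.rows = ...` attribute would not.
```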
@@ -139,9 +141,11 @@ def _cayley_batch(
             R.add_(Q_squared, alpha=2.0)

             Q_power = Q_squared
-            for i in range(3, num_neumann_terms):
+            for _ in range(3, num_neumann_terms - 1):
                 Q_power = torch.bmm(Q_power, Q_skew)
                 R.add_(Q_power, alpha=2.0)
+            Q_power = torch.bmm(Q_power, Q_skew)
+            R.add_(Q_power)
         else:
             id_mat = (
                 torch.eye(Q_skew.shape[-1], device=Q_skew.device)
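The coefficient fix here follows from the algebraic identity behind the truncated Cayley–Neumann product: (I + Q)(I + Q + ... + Q^(n-2)) = I + 2Q + 2Q^2 + ... + 2Q^(n-2) + Q^(n-1), so the highest-order term enters with coefficient 1, not 2. The old loop gave every term `alpha=2.0`; the patch peels the last term out of the loop and adds it with the default `alpha=1.0`. A quick numerical check of the identity (a standalone sketch, not the library code):

```python
import torch

torch.manual_seed(0)
n = 5  # plays the role of num_cayley_neumann_terms
A = torch.randn(2, 8, 8)
Q = 0.05 * (A - A.transpose(-1, -2))  # batch of small skew-symmetric matrices
I = torch.eye(8).expand_as(Q)

# Patched truncation: I + 2Q + 2Q^2 + ... + 2Q^(n-2) + Q^(n-1)
R = I + 2.0 * Q
Q_power = torch.bmm(Q, Q)
R = R + 2.0 * Q_power
for _ in range(3, n - 1):
    Q_power = torch.bmm(Q_power, Q)
    R = R + 2.0 * Q_power
Q_power = torch.bmm(Q_power, Q)
R = R + Q_power  # final term with coefficient 1

# Reference: the product form (I + Q) @ sum_{k=0}^{n-2} Q^k
S, P = I.clone(), I.clone()
for _ in range(1, n - 1):
    P = torch.bmm(P, Q)
    S = S + P
ref = torch.bmm(I + Q, S)

print(torch.allclose(R, ref, atol=1e-6))  # True: the top term enters once, not twice
```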
@@ -621,9 +625,13 @@ def unmerge(self) -> None:
             if active_adapter in self.oft_R.keys():
                 oft_mat = self.get_delta_weight(active_adapter)

+                previous_dtype = oft_mat.dtype
+                if previous_dtype != torch.float32:
+                    oft_mat = oft_mat.to(torch.float32)
+
                 orig_weights = self.get_base_layer().weight.data
                 orig_weights = torch.transpose(orig_weights, 0, 1)
-                orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
+                orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
                 orig_weights = torch.transpose(orig_weights, 0, 1)

                 base_layer.weight.data = orig_weights.to(orig_dtype)
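The unmerge change matters because, with the Neumann approximation enabled, `oft_mat` is only approximately orthogonal: its transpose is no longer its exact inverse, so `oft_mat.t()` leaves a residual error in the restored weights. The patch uses an explicit `torch.linalg.inv`, upcast to float32 first since `torch.linalg.inv` does not support half-precision dtypes. A standalone sketch of the drift that motivates this (illustrative sizes):

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 8)
Q = 0.1 * (A - A.T)  # skew-symmetric generator
I = torch.eye(8)

exact = torch.linalg.solve(I - Q, I + Q)        # exact Cayley transform: orthogonal
neumann = I + 2 * Q + 2 * (Q @ Q) + Q @ Q @ Q   # truncated series: only nearly orthogonal

for name, R in [("exact", exact), ("neumann", neumann)]:
    err_t = (R.T @ R - I).abs().max()                    # transpose-as-inverse residual
    err_inv = (torch.linalg.inv(R) @ R - I).abs().max()  # true inverse residual
    print(f"{name}: transpose err {err_t:.2e}, inverse err {err_inv:.2e}")
# The Neumann case shows a visibly larger transpose residual,
# while the explicit inverse restores the identity to machine precision.
```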
@@ -855,13 +863,17 @@ def unmerge(self) -> None:
             if active_adapter in self.oft_R.keys():
                 oft_mat = self.get_delta_weight(active_adapter)

+                previous_dtype = oft_mat.dtype
+                if previous_dtype != torch.float32:
+                    oft_mat = oft_mat.to(torch.float32)
+
                 orig_weights = self.get_base_layer().weight.data.clone()
                 orig_weights = orig_weights.view(
                     self.out_features,
                     self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
                 )
                 orig_weights = torch.transpose(orig_weights, 0, 1)
-                orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
+                orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
                 orig_weights = torch.transpose(orig_weights, 0, 1)
                 orig_weights = orig_weights.view(
                     self.out_features,
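The Conv2d branch is the same inverse-based fix after flattening the kernel into a 2-D matrix (the surrounding context assumes a square kernel, using `kernel_size[0]` twice). A round-trip sketch under assumed shapes, with all names illustrative:

```python
import torch

out_f, in_f, k = 4, 3, 3
n = in_f * k * k
W = torch.randn(out_f, in_f, k, k)  # conv weight

# An orthogonal stand-in for the OFT matrix, via the exact Cayley transform.
A = torch.randn(n, n)
Q = 0.05 * (A - A.T)
R = torch.linalg.solve(torch.eye(n) - Q, torch.eye(n) + Q)

# Merge: flatten, transpose, rotate (mirrors the surrounding code's shapes).
flat = torch.transpose(W.view(out_f, n), 0, 1)  # (n, out_f)
merged = torch.mm(R, flat)

# Unmerge with the explicit inverse, as in the patch, then reshape back.
restored = torch.mm(torch.linalg.inv(R), merged)
restored = torch.transpose(restored, 0, 1).view(out_f, in_f, k, k)

print(torch.allclose(restored, W, atol=1e-5))  # True: the inverse undoes the merge
```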