@@ -71,15 +71,16 @@ def merge_lora_weights(self) -> None:
7171                f"cuda:{ torch .cuda .current_device ()}  ).full_tensor ()
7272            data  +=  (self .slice_lora_b_weights (self .lora_B )
7373                     @ self .slice_lora_a_weights (self .lora_A )).to (data )
74-             self .base_layer .weight .data  =  distribute_tensor (
75-                 data , mesh , placements = placements ).to (current_device )
74+             self .base_layer .weight  =  nn .Parameter (
75+                 distribute_tensor (data , mesh ,
76+                                   placements = placements ).to (current_device ))
7677        else :
7778            current_device  =  self .base_layer .weight .data .device 
78-             data  =  self .base_layer .weight .data . to (
79+             data  =  self .base_layer .weight .to (
7980                f"cuda:{ torch .cuda .current_device ()}  )
8081            data  +=  \
8182                (self .slice_lora_b_weights (self .lora_B ) @ self .slice_lora_a_weights (self .lora_A )).to (data )
82-             self .base_layer .weight . data  =  data .to (current_device )
83+             self .base_layer .weight  =  nn . Parameter ( data .to (current_device ) )
8384        self .merged  =  True 
8485
8586    @torch .no_grad () 
@@ -106,8 +107,8 @@ def unmerge_lora_weights(self) -> None:
106107                f"cuda:{ torch .cuda .current_device ()}  ).full_tensor ()
107108            data  -=  self .slice_lora_b_weights (
108109                self .lora_B ) @ self .slice_lora_a_weights (self .lora_A )
109-             self .base_layer .weight . data  =  distribute_tensor (
110-                 data , mesh , placements = placement ).to (device )
110+             self .base_layer .weight  =  nn . Parameter (
111+                 distribute_tensor ( data , mesh , placements = placement ).to (device ) )
111112        else :
112113            self .base_layer .weight .data  -=  \
113114                self .slice_lora_b_weights (self .lora_B ) @\
0 commit comments