@@ -27,7 +27,6 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
     parser.add_argument(
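PEFT scales the LoRA update by `lora_alpha / r`, so with the old behaviour (`lora_alpha` hard-wired to `args.rank`) the effective scale was always 1.0; the new `--lora_alpha` flag makes it tunable. A minimal sketch of that relationship using the script's defaults; the helper below is illustrative and not part of the script:

```python
# Effective LoRA scaling as applied by PEFT's (non-rsLoRA) LoRA layers:
# the low-rank update B @ A is multiplied by lora_alpha / r.
def lora_scale(rank: int, lora_alpha: int) -> float:
    return lora_alpha / rank

print(lora_scale(rank=4, lora_alpha=4))   # 1.0 -- previous behaviour, alpha tied to rank
print(lora_scale(rank=16, lora_alpha=8))  # 0.5 -- alpha now set independently via --lora_alpha
```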
@@ -1238,7 +1243,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
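`_collate_lora_metadata` collects the LoRA adapter configuration (rank, alpha, target modules, ...) of each module in `modules_to_save`, and the extra keyword arguments let `save_lora_weights` store that configuration alongside the weights. A hedged sketch of how the stored metadata could be inspected afterwards, assuming the default `pytorch_lora_weights.safetensors` filename diffusers writes into the save directory:

```python
# Sketch: read back the metadata written next to the LoRA weights.
# The filename is the diffusers default; point the path at your checkpoint / output dir.
from safetensors import safe_open

with safe_open("output_dir/pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())  # expected to contain the collated LoRA adapter config(s)
```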
@@ -1889,23 +1897,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None
 
         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
         # Final inference
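For inference, the trained adapter can be loaded back through the standard LoRA loader; the point of the stored metadata is that loading no longer has to assume `lora_alpha == rank`. A minimal sketch, assuming the FLUX.1-dev base model and the training run's `--output_dir`:

```python
# Sketch: load the trained LoRA for inference. Model id, dtype, prompt, and paths are examples.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("output_dir")  # finds pytorch_lora_weights.safetensors and its metadata
pipe.to("cuda")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=28).images[0]
image.save("lora_sample.png")
```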