@@ -349,14 +349,19 @@ def parse_args(input_args=None):
349349 "--optimizer" ,
350350 type = str ,
351351 default = "AdamW" ,
352- help = ( 'The optimizer type to use. Choose between ["AdamW", "prodigy"]' ) ,
352+ choices = ["AdamW" , "Prodigy" , "AdEMAMix" ] ,
353353 )
354354
355355 parser .add_argument (
356356 "--use_8bit_adam" ,
357357 action = "store_true" ,
358358 help = "Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW" ,
359359 )
360+ parser .add_argument (
361+ "--use_8bit_ademamix" ,
362+ action = "store_true" ,
363+ help = "Whether or not to use 8-bit AdEMAMix from bitsandbytes." ,
364+ )
360365
361366 parser .add_argument (
362367 "--adam_beta1" , type = float , default = 0.9 , help = "The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
     params_to_optimize = [transformer_parameters_with_lr]

     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
         )
-        args.optimizer = "adamw"

-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )

@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
             eps=args.adam_epsilon,
         )

+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
+
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
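Unlike the AdamW and Prodigy branches, the new AdEMAMix branch passes only `params_to_optimize` and leaves every hyperparameter at the bitsandbytes defaults; the learning rate still takes effect because it travels inside the parameter groups. A small sketch of that mechanism, using a plain PyTorch optimizer as a stand-in (the shape of `transformer_parameters_with_lr` is assumed from the surrounding script):

```python
import torch

# Stand-in for the LoRA parameters collected earlier in the script.
params = [torch.nn.Parameter(torch.zeros(4))]
transformer_parameters_with_lr = {"params": params, "lr": 1e-4}
params_to_optimize = [transformer_parameters_with_lr]

# A group-level "lr" overrides the optimizer's constructor default, which is
# why optimizer_class(params_to_optimize) needs no explicit lr= argument.
optimizer = torch.optim.AdamW(params_to_optimize)
print(optimizer.param_groups[0]["lr"])  # 0.0001, not the AdamW default 0.001
```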
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
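The `- 1` matters because `block_out_channels` has one more entry than the VAE has downsampling stages, and the halved height/width reflect Flux packing latents into 2x2 patches before building position ids. A worked example under assumed Flux VAE config values (four blocks, i.e. three 2x downsampling stages):

```python
# Assumed Flux VAE config values, shown only to make the arithmetic concrete.
vae_config_block_out_channels = [128, 256, 512, 512]

vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)  # 8 (previously 16)

# A 1024x1024 training image yields 128x128 latents; the position ids are
# built on the packed 2x2-patch grid, hence the // 2 in the call above.
pixel_h = pixel_w = 1024
latent_h, latent_w = pixel_h // vae_scale_factor, pixel_w // vae_scale_factor  # 128, 128
packed_h, packed_w = latent_h // 2, latent_w // 2                              # 64, 64
print(vae_scale_factor, (latent_h, latent_w), (packed_h, packed_w))
```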
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
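Reading `.config` through `unwrap_model(...)` matters under distributed (or compiled) training: `accelerator.prepare` can wrap the transformer in `DistributedDataParallel`, and wrapper modules do not proxy plain Python attributes such as `config`. A minimal illustration with a generic wrapper standing in for DDP (hypothetical classes, not part of the script):

```python
import torch.nn as nn

class Inner(nn.Module):
    """Stand-in for the Flux transformer: exposes a plain .config attribute."""
    def __init__(self):
        super().__init__()
        self.config = {"guidance_embeds": True}

class Wrapper(nn.Module):
    """Stand-in for DistributedDataParallel: keeps the model under .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = Wrapper(Inner())
print(wrapped.module.config["guidance_embeds"])  # True: reached via the inner model
# wrapped.config raises AttributeError, because nn.Module attribute lookup only
# resolves parameters, buffers, and submodules, not arbitrary attributes;
# unwrap_model(transformer) strips the wrapper before .config is read.
```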
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )

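With the scale factor corrected to 8, the height and width passed to `_unpack_latents` reduce to latent size times scale factor; the old `int(... * vae_scale_factor / 2)` form only landed on the right pixel size because the previous scale factor was twice too large. Continuing the worked example above (1024px image, 128x128 latents):

```python
# model_input.shape[2] and shape[3] are the latent height and width.
latent_h = latent_w = 128
vae_scale_factor = 8

height = latent_h * vae_scale_factor  # 1024, new expression
width = latent_w * vae_scale_factor   # 1024
old_height = int(latent_h * 16 / 2)   # also 1024, but only because the old
                                      # scale factor (16) was 2x too large
print(height, width, old_height)
```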