 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
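Note on the save hook changes: the type checks now unwrap `model` itself because, under DeepSpeed (and DDP or `torch.compile`), accelerate hands the hook the wrapped engine rather than the bare module, so a direct `isinstance()` against the transformer class would fail. Likewise, `weights` can arrive empty under DeepSpeed (the engine checkpoints its own state), which is why `pop()` is now guarded. Below is a minimal sketch of the kind of `unwrap_model` helper these checks rely on, assuming the usual accelerate/diffusers utilities; the script's own definition sits outside the hunks shown here, and `accelerator` comes from the surrounding script.

```python
from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    # strip the DDP / DeepSpeed wrapper that accelerate puts around the module
    model = accelerator.unwrap_model(model)
    # strip the torch.compile wrapper, if any, to recover the original nn.Module
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```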
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )

         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

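In the DeepSpeed branch above, the `models` list cannot be relied on to hand back the bare modules, so the load hook instantiates fresh base models from `args.pretrained_model_name_or_path` and re-attaches the LoRA adapter before the saved weights are applied. Applying the adapter weights collected in `lora_state_dict` typically follows a pattern like the hedged sketch below; the `transformer.` key prefix and the peft helper are illustrative, not a claim about the script's exact loading code.

```python
from peft import set_peft_model_state_dict

# keep only the transformer's LoRA entries and drop the pipeline-level prefix
transformer_state_dict = {
    k.removeprefix("transformer."): v
    for k, v in lora_state_dict.items()
    if k.startswith("transformer.")
}
# push the weights into the adapter that add_adapter() created on transformer_
incompatible_keys = set_peft_model_state_dict(
    transformer_, transformer_state_dict, adapter_name="default"
)
```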
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
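For context, the widened condition in the last hunk lets non-main ranks enter the checkpointing branch because DeepSpeed's sharded state saving needs every process to call `accelerator.save_state()`. The two hooks only take effect once they are registered on the `Accelerator`; that wiring is outside the hunks shown here, but it follows the usual accelerate API, sketched below.

```python
import os

# register the hooks so save_state()/load_state() route LoRA weights through them
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# inside the training loop, the guarded branch above boils down to:
if global_step % args.checkpointing_steps == 0:
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)  # triggers save_model_hook on every rank that calls it
```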