 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
@@ -1181,13 +1181,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
@@ -1197,13 +1199,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    transformer_ = model
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)

@@ -1655,7 +1664,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None: