1717import contextlib
1818import copy
1919import functools
20+ import gc
2021import logging
2122import math
2223import os
5253from diffusers .training_utils import compute_density_for_timestep_sampling , compute_loss_weighting_for_sd3 , free_memory
5354from diffusers .utils import check_min_version , is_wandb_available , make_image_grid
5455from diffusers .utils .hub_utils import load_or_create_model_card , populate_model_card
56+ from diffusers .utils .testing_utils import backend_empty_cache
5557from diffusers .utils .torch_utils import is_compiled_module
5658
5759
@@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
7476
7577 pipeline = StableDiffusion3ControlNetPipeline .from_pretrained (
7678 args .pretrained_model_name_or_path ,
77- controlnet = controlnet ,
79+ controlnet = None ,
7880 safety_checker = None ,
81+ transformer = None ,
7982 revision = args .revision ,
8083 variant = args .variant ,
8184 torch_dtype = weight_dtype ,
@@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
102105 "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
103106 )
104107
108+ with torch .no_grad ():
109+ (
110+ prompt_embeds ,
111+ negative_prompt_embeds ,
112+ pooled_prompt_embeds ,
113+ negative_pooled_prompt_embeds ,
114+ ) = pipeline .encode_prompt (
115+ validation_prompts ,
116+ prompt_2 = None ,
117+ prompt_3 = None ,
118+ )
119+
120+ del pipeline
121+ gc .collect ()
122+ backend_empty_cache (accelerator .device .type )
123+
124+ pipeline = StableDiffusion3ControlNetPipeline .from_pretrained (
125+ args .pretrained_model_name_or_path ,
126+ controlnet = controlnet ,
127+ safety_checker = None ,
128+ text_encoder = None ,
129+ text_encoder_2 = None ,
130+ text_encoder_3 = None ,
131+ revision = args .revision ,
132+ variant = args .variant ,
133+ torch_dtype = weight_dtype ,
134+ )
135+ pipeline .enable_model_cpu_offload (device = accelerator .device .type )
136+ pipeline .set_progress_bar_config (disable = True )
137+
105138 image_logs = []
106139 inference_ctx = contextlib .nullcontext () if is_final_validation else torch .autocast (accelerator .device .type )
107140
108- for validation_prompt , validation_image in zip ( validation_prompts , validation_images ):
141+ for i , validation_image in enumerate ( validation_images ):
109142 validation_image = Image .open (validation_image ).convert ("RGB" )
143+ validation_prompt = validation_prompts [i ]
110144
111145 images = []
112146
113147 for _ in range (args .num_validation_images ):
114148 with inference_ctx :
115149 image = pipeline (
116- validation_prompt , control_image = validation_image , num_inference_steps = 20 , generator = generator
150+ prompt_embeds = prompt_embeds [i ].unsqueeze (0 ),
151+ negative_prompt_embeds = negative_prompt_embeds [i ].unsqueeze (0 ),
152+ pooled_prompt_embeds = pooled_prompt_embeds [i ].unsqueeze (0 ),
153+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds [i ].unsqueeze (0 ),
154+ control_image = validation_image ,
155+ num_inference_steps = 20 ,
156+ generator = generator ,
117157 ).images [0 ]
118158
119159 images .append (image )
@@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
655695 dataset = load_dataset (
656696 args .train_data_dir ,
657697 cache_dir = args .cache_dir ,
698+ trust_remote_code = True ,
658699 )
659700 # See more about loading custom images at
660701 # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
0 commit comments