Skip to content

Commit ed8bda9

Browse files
authored
Merge branch 'main' into cogvideox-5b-i2v
2 parents 6f313e8 + 2454b98 commit ed8bda9

File tree

9 files changed

+118
-23
lines changed

9 files changed

+118
-23
lines changed

examples/dreambooth/README_flux.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ Instead, only a subset of these activations (the checkpoints) are stored and the
221221
### 8-bit-Adam Optimizer
222222
When training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
223223
Make sure to install `bitsandbytes` if you want to do so.
224-
### latent caching
224+
### Latent caching
225225
When training without validation runs, we can pre-encode the training images with the VAE, and then delete the VAE to free up some memory.
226-
to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`
226+
To enable `latent_caching`, simply pass `--cache_latents`.
227+
### Precision of saved LoRA layers
228+
By default, trained transformer layers are saved in the dtype in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
229+
This reduces memory requirements significantly without a noticeable loss in quality. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
230+
227231
## Other notes
228232
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️

examples/dreambooth/test_dreambooth_lora_flux.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self):
103103
)
104104
self.assertTrue(starts_with_expected_prefix)
105105

106+
def test_dreambooth_lora_latent_caching(self):
107+
with tempfile.TemporaryDirectory() as tmpdir:
108+
test_args = f"""
109+
{self.script_path}
110+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
111+
--instance_data_dir {self.instance_data_dir}
112+
--instance_prompt {self.instance_prompt}
113+
--resolution 64
114+
--train_batch_size 1
115+
--gradient_accumulation_steps 1
116+
--max_train_steps 2
117+
--cache_latents
118+
--learning_rate 5.0e-04
119+
--scale_lr
120+
--lr_scheduler constant
121+
--lr_warmup_steps 0
122+
--output_dir {tmpdir}
123+
""".split()
124+
125+
run_command(self._launch_args + test_args)
126+
# save_pretrained smoke test
127+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
128+
129+
# make sure the state_dict has the correct naming in the parameters.
130+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
131+
is_lora = all("lora" in k for k in lora_state_dict.keys())
132+
self.assertTrue(is_lora)
133+
134+
# when not training the text encoder, all the parameters in the state dict should start
135+
# with `"transformer"` in their names.
136+
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
137+
self.assertTrue(starts_with_transformer)
138+
106139
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
107140
with tempfile.TemporaryDirectory() as tmpdir:
108141
test_args = f"""

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,14 @@ def log_validation(
154154
accelerator,
155155
pipeline_args,
156156
epoch,
157+
torch_dtype,
157158
is_final_validation=False,
158159
):
159160
logger.info(
160161
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
161162
f" {args.validation_prompt}."
162163
)
163-
pipeline = pipeline.to(accelerator.device)
164+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
164165
pipeline.set_progress_bar_config(disable=True)
165166

166167
# run inference
@@ -1717,6 +1718,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17171718
accelerator=accelerator,
17181719
pipeline_args=pipeline_args,
17191720
epoch=epoch,
1721+
torch_dtype=weight_dtype,
17201722
)
17211723
if not args.train_text_encoder:
17221724
del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17611763
pipeline_args=pipeline_args,
17621764
epoch=epoch,
17631765
is_final_validation=True,
1766+
torch_dtype=weight_dtype,
17641767
)
17651768

17661769
if args.push_to_hub:

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def log_validation(
122122
accelerator,
123123
pipeline_args,
124124
epoch,
125+
torch_dtype,
125126
is_final_validation=False,
126127
):
127128
logger.info(
@@ -141,7 +142,7 @@ def log_validation(
141142

142143
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
143144

144-
pipeline = pipeline.to(accelerator.device)
145+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
145146
pipeline.set_progress_bar_config(disable=True)
146147

147148
# run inference
@@ -1360,6 +1361,7 @@ def compute_text_embeddings(prompt):
13601361
accelerator,
13611362
pipeline_args,
13621363
epoch,
1364+
torch_dtype=weight_dtype,
13631365
)
13641366

13651367
# Save the lora layers
@@ -1402,6 +1404,7 @@ def compute_text_embeddings(prompt):
14021404
pipeline_args,
14031405
epoch,
14041406
is_final_validation=True,
1407+
torch_dtype=weight_dtype,
14051408
)
14061409

14071410
if args.push_to_hub:

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import copy
18-
import gc
1918
import itertools
2019
import logging
2120
import math
@@ -56,6 +55,7 @@
5655
from diffusers.training_utils import (
5756
_set_state_dict_into_text_encoder,
5857
cast_training_params,
58+
clear_objs_and_retain_memory,
5959
compute_density_for_timestep_sampling,
6060
compute_loss_weighting_for_sd3,
6161
)
@@ -170,13 +170,14 @@ def log_validation(
170170
accelerator,
171171
pipeline_args,
172172
epoch,
173+
torch_dtype,
173174
is_final_validation=False,
174175
):
175176
logger.info(
176177
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
177178
f" {args.validation_prompt}."
178179
)
179-
pipeline = pipeline.to(accelerator.device)
180+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
180181
pipeline.set_progress_bar_config(disable=True)
181182

182183
# run inference
@@ -599,6 +600,12 @@ def parse_args(input_args=None):
599600
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
600601
),
601602
)
603+
parser.add_argument(
604+
"--cache_latents",
605+
action="store_true",
606+
default=False,
607+
help="Cache the VAE latents",
608+
)
602609
parser.add_argument(
603610
"--report_to",
604611
type=str,
@@ -619,6 +626,15 @@ def parse_args(input_args=None):
619626
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
620627
),
621628
)
629+
parser.add_argument(
630+
"--upcast_before_saving",
631+
action="store_true",
632+
default=False,
633+
help=(
634+
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
635+
"Defaults to precision dtype used for training to save memory"
636+
),
637+
)
622638
parser.add_argument(
623639
"--prior_generation_precision",
624640
type=str,
@@ -1421,12 +1437,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14211437

14221438
# Clear the memory here
14231439
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1424-
del tokenizers, text_encoders
1425-
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
1426-
del text_encoder_one, text_encoder_two
1427-
gc.collect()
1428-
if torch.cuda.is_available():
1429-
torch.cuda.empty_cache()
1440+
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
14301441

14311442
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
14321443
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1456,6 +1467,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14561467
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
14571468
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
14581469

1470+
vae_config_shift_factor = vae.config.shift_factor
1471+
vae_config_scaling_factor = vae.config.scaling_factor
1472+
vae_config_block_out_channels = vae.config.block_out_channels
1473+
if args.cache_latents:
1474+
latents_cache = []
1475+
for batch in tqdm(train_dataloader, desc="Caching latents"):
1476+
with torch.no_grad():
1477+
batch["pixel_values"] = batch["pixel_values"].to(
1478+
accelerator.device, non_blocking=True, dtype=weight_dtype
1479+
)
1480+
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
1481+
1482+
if args.validation_prompt is None:
1483+
clear_objs_and_retain_memory([vae])
1484+
14591485
# Scheduler and math around the number of training steps.
14601486
overrode_max_train_steps = False
14611487
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1578,7 +1604,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15781604
if args.train_text_encoder:
15791605
models_to_accumulate.extend([text_encoder_one])
15801606
with accelerator.accumulate(models_to_accumulate):
1581-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
15821607
prompts = batch["prompts"]
15831608

15841609
# encode batch prompts when custom prompts are provided for each image -
@@ -1612,11 +1637,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16121637
)
16131638

16141639
# Convert images to latent space
1615-
model_input = vae.encode(pixel_values).latent_dist.sample()
1616-
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
1640+
if args.cache_latents:
1641+
model_input = latents_cache[step].sample()
1642+
else:
1643+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1644+
model_input = vae.encode(pixel_values).latent_dist.sample()
1645+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
16171646
model_input = model_input.to(dtype=weight_dtype)
16181647

1619-
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
1648+
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
16201649

16211650
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
16221651
model_input.shape[0],
@@ -1785,17 +1814,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17851814
accelerator=accelerator,
17861815
pipeline_args=pipeline_args,
17871816
epoch=epoch,
1817+
torch_dtype=weight_dtype,
17881818
)
17891819
if not args.train_text_encoder:
1790-
del text_encoder_one, text_encoder_two
1791-
torch.cuda.empty_cache()
1792-
gc.collect()
1820+
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
17931821

17941822
# Save the lora layers
17951823
accelerator.wait_for_everyone()
17961824
if accelerator.is_main_process:
17971825
transformer = unwrap_model(transformer)
1798-
transformer = transformer.to(torch.float32)
1826+
if args.upcast_before_saving:
1827+
transformer.to(torch.float32)
1828+
else:
1829+
transformer = transformer.to(weight_dtype)
17991830
transformer_lora_layers = get_peft_model_state_dict(transformer)
18001831

18011832
if args.train_text_encoder:
@@ -1832,6 +1863,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18321863
pipeline_args=pipeline_args,
18331864
epoch=epoch,
18341865
is_final_validation=True,
1866+
torch_dtype=weight_dtype,
18351867
)
18361868

18371869
if args.push_to_hub:

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,14 @@ def log_validation(
179179
accelerator,
180180
pipeline_args,
181181
epoch,
182+
torch_dtype,
182183
is_final_validation=False,
183184
):
184185
logger.info(
185186
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
186187
f" {args.validation_prompt}."
187188
)
188-
pipeline = pipeline.to(accelerator.device)
189+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
189190
pipeline.set_progress_bar_config(disable=True)
190191

191192
# run inference
@@ -1788,6 +1789,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17881789
accelerator=accelerator,
17891790
pipeline_args=pipeline_args,
17901791
epoch=epoch,
1792+
torch_dtype=weight_dtype,
17911793
)
17921794
objs = []
17931795
if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18401842
pipeline_args=pipeline_args,
18411843
epoch=epoch,
18421844
is_final_validation=True,
1845+
torch_dtype=weight_dtype,
18431846
)
18441847

18451848
if args.push_to_hub:

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def log_validation(
180180
accelerator,
181181
pipeline_args,
182182
epoch,
183+
torch_dtype,
183184
is_final_validation=False,
184185
):
185186
logger.info(
@@ -201,7 +202,7 @@ def log_validation(
201202

202203
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
203204

204-
pipeline = pipeline.to(accelerator.device)
205+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
205206
pipeline.set_progress_bar_config(disable=True)
206207

207208
# run inference
@@ -1890,6 +1891,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18901891
accelerator,
18911892
pipeline_args,
18921893
epoch,
1894+
torch_dtype=weight_dtype,
18931895
)
18941896

18951897
# Save the lora layers
@@ -1955,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
19551957
pipeline_args,
19561958
epoch,
19571959
is_final_validation=True,
1960+
torch_dtype=weight_dtype,
19581961
)
19591962

19601963
if args.push_to_hub:

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,14 @@ def log_validation(
157157
accelerator,
158158
pipeline_args,
159159
epoch,
160+
torch_dtype,
160161
is_final_validation=False,
161162
):
162163
logger.info(
163164
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
164165
f" {args.validation_prompt}."
165166
)
166-
pipeline = pipeline.to(accelerator.device)
167+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
167168
pipeline.set_progress_bar_config(disable=True)
168169

169170
# run inference
@@ -1725,6 +1726,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17251726
accelerator=accelerator,
17261727
pipeline_args=pipeline_args,
17271728
epoch=epoch,
1729+
torch_dtype=weight_dtype,
17281730
)
17291731
if not args.train_text_encoder:
17301732
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17751777
pipeline_args=pipeline_args,
17761778
epoch=epoch,
17771779
is_final_validation=True,
1780+
torch_dtype=weight_dtype,
17781781
)
17791782

17801783
if args.push_to_hub:

0 commit comments

Comments
 (0)