Commit 37e3603

[Flux Dreambooth lora] add latent caching (huggingface#9160)
* add ostris trainer to README & add cache latents of vae
* add ostris trainer to README & add cache latents of vae
* style
* readme
* add test for latent caching
* add ostris noise scheduler
  https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
* style
* fix import
* style
* fix tests
* style
* --change upcasting of transformer?
* update readme according to main
* keep only latent caching
* add configurable param for final saving of trained layers- --upcast_before_saving
* style
* Update examples/dreambooth/README_flux.md
  Co-authored-by: Sayak Paul <[email protected]>
* Update examples/dreambooth/README_flux.md
  Co-authored-by: Sayak Paul <[email protected]>
* use clear_objs_and_retain_memory from utilities
* style
---------
Co-authored-by: Sayak Paul <[email protected]>
1 parent e2ead7c commit 37e3603

File tree: 3 files changed, +83 −17 lines
examples/dreambooth/README_flux.md
examples/dreambooth/test_dreambooth_lora_flux.py
examples/dreambooth/train_dreambooth_lora_flux.py


examples/dreambooth/README_flux.md

Lines changed: 6 additions & 2 deletions

@@ -221,8 +221,12 @@ Instead, only a subset of these activations (the checkpoints) are stored and the
 ### 8-bit-Adam Optimizer
 When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
 Make sure to install `bitsandbytes` if you want to do so.
-### latent caching
+### Latent caching
 When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory.
-to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`
+to enable `latent_caching` simply pass `--cache_latents`.
+### Precision of saved LoRA layers
+By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
+This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
+
 ## Other notes
 Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
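The README section above describes the caching flow only at a high level: encode every training image once with the VAE, keep the (much smaller) latent distributions, and drop the VAE when no validation runs need it. Below is a minimal, self-contained sketch of that same pattern in plain PyTorch. The `ToyVAEEncoder` and `CachedLatentDist` classes are illustrative stand-ins rather than diffusers APIs; the real script caches `vae.encode(...).latent_dist` and calls `.sample()` on it each step, as the `train_dreambooth_lora_flux.py` diff further down shows.

```python
# Illustrative sketch of the latent-caching pattern (stand-in encoder, not the diffusers VAE).
import torch
import torch.nn as nn


class ToyVAEEncoder(nn.Module):
    """Stand-in for a VAE encoder that predicts a mean and log-variance per image."""

    def __init__(self, channels: int = 3, latent_channels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(channels, 2 * latent_channels, kernel_size=8, stride=8)

    def forward(self, pixel_values: torch.Tensor):
        mean, logvar = self.conv(pixel_values).chunk(2, dim=1)
        return mean, logvar


class CachedLatentDist:
    """Mimics a cached latent distribution that can be re-sampled on every training step."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        self.mean, self.std = mean, torch.exp(0.5 * logvar)

    def sample(self) -> torch.Tensor:
        return self.mean + self.std * torch.randn_like(self.mean)


encoder = ToyVAEEncoder()
dataloader = [torch.randn(1, 3, 64, 64) for _ in range(4)]  # stand-in for train_dataloader

# Pass 1: encode every batch once and keep only the latent distributions.
latents_cache = []
with torch.no_grad():
    for pixel_values in dataloader:
        latents_cache.append(CachedLatentDist(*encoder(pixel_values)))

del encoder  # the encoder is no longer needed, so its memory can be reclaimed

# Training loop: draw a fresh latent sample from the cache instead of re-encoding images.
for step, _ in enumerate(dataloader):
    model_input = latents_cache[step].sample()
    print(step, tuple(model_input.shape))
```

Sampling from the stored distribution each step (rather than caching a single sampled tensor) keeps the small amount of latent noise augmentation that re-encoding would have provided.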

examples/dreambooth/test_dreambooth_lora_flux.py

Lines changed: 33 additions & 0 deletions

@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self):
             )
             self.assertTrue(starts_with_expected_prefix)
 
+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
     def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
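To exercise just this new test from a local diffusers checkout, something like the following should select it by name. This is a hypothetical invocation: it assumes you are running from the repository root with the example test dependencies installed, and the exact flags or required environment variables may differ in your setup.

```python
# Hypothetical local invocation of the new test; paths assume a diffusers repository checkout.
import pytest

pytest.main(
    [
        "examples/dreambooth/test_dreambooth_lora_flux.py",
        "-k",
        "test_dreambooth_lora_latent_caching",
        "-s",
    ]
)
```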

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 44 additions & 15 deletions

@@ -15,7 +15,6 @@
 
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -56,6 +55,7 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )
@@ -600,6 +600,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -620,6 +626,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1422,12 +1437,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
-        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1457,6 +1467,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_block_out_channels = vae.config.block_out_channels
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            clear_objs_and_retain_memory([vae])
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1579,7 +1604,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             if args.train_text_encoder:
                 models_to_accumulate.extend([text_encoder_one])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
@@ -1613,11 +1637,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
@@ -1789,15 +1817,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                    torch_dtype=weight_dtype,
                )
                if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
-                    gc.collect()
+                    clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
 
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(torch.float32)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         if args.train_text_encoder:
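For readers unfamiliar with `clear_objs_and_retain_memory`, the diff above swaps the manual `del` / `gc.collect()` / `torch.cuda.empty_cache()` sequence for that utility from `diffusers.training_utils`. The snippet below is only a rough sketch of what such a helper does, written under the assumption that it mirrors the removed code; it is not the actual diffusers implementation.

```python
# Rough sketch of a memory-cleanup helper in the spirit of clear_objs_and_retain_memory;
# the real implementation lives in diffusers.training_utils and may differ.
import gc
from typing import Any, Iterable

import torch


def clear_objs_and_retain_memory_sketch(objs: Iterable[Any]) -> None:
    """Drop references to the given objects, run garbage collection, and free cached GPU memory."""
    for obj in list(objs):
        del obj  # removes only this local reference; callers must also drop their own references
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    tensors = [torch.randn(256, 256) for _ in range(4)]
    clear_objs_and_retain_memory_sketch(tensors)
    del tensors  # the caller's reference also has to go before the memory is actually reclaimed
```

The practical point is that the memory is only reclaimed once no live references remain, which is why the training script stops referencing the VAE and text encoders after handing them to the helper.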
