Skip to content

Commit 96a9097

Browse files
denialzz and sayakpaul authored
Add offload option in flux-control training (huggingface#10225)
* Add offload option in flux-control training
* Update examples/flux-control/train_control_flux.py (Co-authored-by: Sayak Paul <[email protected]>)
* Modify help message
* Fix format

Co-authored-by: Sayak Paul <[email protected]>
1 parent a5f35ee commit 96a9097

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

examples/flux-control/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
3636
--max_train_steps=5000 \
3737
--validation_image="openpose.png" \
3838
--validation_prompt="A couple, 4k photo, highly detailed" \
39+
--offload \
3940
--seed="0" \
4041
--push_to_hub
4142
```
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
154155
--validation_steps=200 \
155156
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
156157
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
158+
--offload \
157159
--seed="0" \
158160
--push_to_hub
159161
```

examples/flux-control/train_control_flux.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
541541
default=1.29,
542542
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
543543
)
544+
parser.add_argument(
545+
"--offload",
546+
action="store_true",
547+
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
548+
)
544549

545550
if input_args is not None:
546551
args = parser.parse_args(input_args)
@@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
9991004
control_latents = encode_images(
10001005
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
10011006
)
1002-
# offload vae to CPU.
1003-
vae.cpu()
1007+
if args.offload:
1008+
# offload vae to CPU.
1009+
vae.cpu()
10041010

10051011
# Sample a random timestep for each image
10061012
# for weighting schemes where we sample timesteps non-uniformly
@@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10641070
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
10651071
prompt_embeds.zero_()
10661072
pooled_prompt_embeds.zero_()
1067-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1073+
if args.offload:
1074+
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
10681075

10691076
# Predict.
10701077
model_pred = flux_transformer(

examples/flux-control/train_control_lora_flux.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
573573
default=1.29,
574574
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
575575
)
576+
parser.add_argument(
577+
"--offload",
578+
action="store_true",
579+
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
580+
)
576581

577582
if input_args is not None:
578583
args = parser.parse_args(input_args)
@@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11401145
control_latents = encode_images(
11411146
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
11421147
)
1143-
# offload vae to CPU.
1144-
vae.cpu()
1148+
1149+
if args.offload:
1150+
# offload vae to CPU.
1151+
vae.cpu()
11451152

11461153
# Sample a random timestep for each image
11471154
# for weighting schemes where we sample timesteps non-uniformly
@@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12051212
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
12061213
prompt_embeds.zero_()
12071214
pooled_prompt_embeds.zero_()
1208-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1215+
if args.offload:
1216+
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
12091217

12101218
# Predict.
12111219
model_pred = flux_transformer(

0 commit comments

Comments (0)