Commit f1ba5df

Merge branch 'kohya-ss:sd3' into sd3

2 parents: b811bc8 + 75933d7

16 files changed: +871 −829 lines

README.md

9 additions, 0 deletions

```diff
@@ -14,6 +14,15 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Mar 6, 2025:
+
+- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
+
+Feb 26, 2025:
+
+- Improved the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
+- The validation loss uses fixed timestep sampling and a fixed random seed, so that it does not fluctuate due to random values.
+
 Jan 25, 2025:
 
 - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
```
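
Conceptually, the new merge utility combines several per-component state dicts into one checkpoint. Below is a minimal, hedged sketch of that idea using the `safetensors` API; the file paths and key prefixes are illustrative assumptions, not the script's actual interface (run `tools/merge_sd3_safetensors.py --help` for the real options).

```python
from safetensors.torch import load_file, save_file

# Assumed prefix -> checkpoint path mapping (illustrative only).
components = {
    "model.diffusion_model.": "sd3_dit.safetensors",
    "first_stage_model.": "sd3_vae.safetensors",
    "text_encoders.clip_l.": "clip_l.safetensors",
    "text_encoders.clip_g.": "clip_g.safetensors",
    "text_encoders.t5xxl.": "t5xxl.safetensors",
}

merged = {}
for prefix, path in components.items():
    state_dict = load_file(path)  # returns {tensor name: torch.Tensor}
    for key, tensor in state_dict.items():
        merged[prefix + key] = tensor

# Write everything as a single .safetensors file.
save_file(merged, "sd3_merged.safetensors")
```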

flux_train_network.py

14 additions, 43 deletions
```diff
@@ -36,7 +36,12 @@ def __init__(self):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
 
-    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
+    def assert_extra_args(
+        self,
+        args,
+        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
+        val_dataset_group: Optional[train_util.DatasetGroup],
+    ):
         super().assert_extra_args(args, train_dataset_group, val_dataset_group)
         # sdxl_train_util.verify_sdxl_training_args(args)
 
```
```diff
@@ -323,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) ->
         self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
         return noise_scheduler
 
-    def encode_images_to_latents(self, args, accelerator, vae, images):
+    def encode_images_to_latents(self, args, vae, images):
         return vae.encode(images)
 
     def shift_scale_latents(self, args, latents):
```
```diff
@@ -341,7 +346,7 @@ def get_noise_pred_and_target(
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
```
```diff
@@ -376,8 +381,7 @@
             t5_attn_mask = None
 
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
-            # if not args.split_mode:
-            # normal forward
+            # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = unet(
```
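
The replacement comment above captures a subtle point: even when the DiT itself is frozen, its forward pass must run with gradients enabled so that gradients can flow *through* it back to the trainable Text Encoder outputs. A minimal sketch of that behavior with stand-in modules (all names here are illustrative):

```python
import torch
import torch.nn as nn

text_encoder = nn.Linear(8, 8)                 # trainable
unet = nn.Linear(8, 8).requires_grad_(False)   # frozen ("not in train mode")

x = torch.randn(2, 8)
cond = text_encoder(x)   # output depends on trainable parameters
pred = unet(cond)        # frozen weights, but the autograd graph is still built
loss = pred.pow(2).mean()
loss.backward()          # gradients flow through the frozen unet
print(text_encoder.weight.grad is not None)  # True
```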
```diff
@@ -390,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
                     guidance=guidance_vec,
                     txt_attention_mask=t5_attn_mask,
                 )
-            """
-            else:
-                # split forward to reduce memory usage
-                assert network.train_blocks == "single", "train_blocks must be single for split mode"
-                with accelerator.autocast():
-                    # move flux lower to cpu, and then move flux upper to gpu
-                    unet.to("cpu")
-                    clean_memory_on_device(accelerator.device)
-                    self.flux_upper.to(accelerator.device)
-
-                    # upper model does not require grad
-                    with torch.no_grad():
-                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
-                            img=packed_noisy_model_input,
-                            img_ids=img_ids,
-                            txt=t5_out,
-                            txt_ids=txt_ids,
-                            y=l_pooled,
-                            timesteps=timesteps / 1000,
-                            guidance=guidance_vec,
-                            txt_attention_mask=t5_attn_mask,
-                        )
-
-                    # move flux upper back to cpu, and then move flux lower to gpu
-                    self.flux_upper.to("cpu")
-                    clean_memory_on_device(accelerator.device)
-                    unet.to(accelerator.device)
-
-                # lower model requires grad
-                intermediate_img.requires_grad_(True)
-                intermediate_txt.requires_grad_(True)
-                vec.requires_grad_(True)
-                pe.requires_grad_(True)
-
-                with torch.set_grad_enabled(is_train and train_unet):
-                    model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
-            """
-
             return model_pred
 
         model_pred = call_dit(
```
```diff
@@ -546,6 +512,11 @@ def forward(hidden_states):
                 text_encoder.to(te_weight_dtype)  # fp8
                 prepare_fp8(text_encoder, weight_dtype)
 
+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        if self.is_swapping_blocks:
+            # prepare for next forward: because backward pass is not called, we need to prepare it here
+            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
```
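
Two validation-related changes land in this file: the `is_train` flag threaded through `get_noise_pred_and_target`, and the new `on_validation_step_end` hook that re-prepares block swapping because no backward pass runs on validation steps. Together with the README's note about fixed timesteps and a fixed seed (PR #1903), the intent is a deterministic, gradient-free validation forward. A hedged sketch of that pattern (the helper and its arguments are illustrative, not the repository's actual API):

```python
import torch

def validation_forward(model, latents, val_seed: int = 42):
    # A fixed seed yields fixed noise and timesteps, so the validation loss
    # does not fluctuate between runs (the idea behind PR #1903).
    generator = torch.Generator(device=latents.device).manual_seed(val_seed)
    noise = torch.randn(latents.shape, generator=generator,
                        device=latents.device, dtype=latents.dtype)
    timesteps = torch.randint(0, 1000, (latents.shape[0],),
                              generator=generator, device=latents.device)

    # Mirrors torch.set_grad_enabled(is_train) with is_train=False:
    # validation runs the same code path but builds no autograd graph.
    with torch.set_grad_enabled(False):
        pred = model(latents + noise, timesteps)
    return pred, noise, timesteps
```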

library/device_utils.py

8 additions, 3 deletions
```diff
@@ -2,6 +2,13 @@
 import gc
 
 import torch
+try:
+    # intel gpu support for pytorch older than 2.5
+    # ipex is not needed after pytorch 2.5
+    import intel_extension_for_pytorch as ipex  # noqa
+except Exception:
+    pass
+
 
 
 try:
```
@@ -14,8 +21,6 @@
1421
HAS_MPS = False
1522

1623
try:
17-
import intel_extension_for_pytorch as ipex # noqa
18-
1924
HAS_XPU = torch.xpu.is_available()
2025
except Exception:
2126
HAS_XPU = False
```diff
@@ -69,7 +74,7 @@ def init_ipex():
 
     This function should run right after importing torch and before doing anything else.
 
-    If IPEX is not available, this function does nothing.
+    If xpu is not available, this function does nothing.
     """
     try:
         if HAS_XPU:
```
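
For context, the capability flags defined here (`HAS_CUDA`, `HAS_XPU`, `HAS_MPS`) feed device selection elsewhere in the library. A hedged sketch of the usual selection order; `pick_device` is an illustrative name, not necessarily the library's actual helper:

```python
import torch

def pick_device() -> torch.device:
    # Order mirrors the capability flags above: CUDA, then Intel XPU
    # (which needs the ipex import on PyTorch < 2.5), then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        return torch.device("xpu")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device())
```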
