From f6703efc65ca99a440b2f9269f4bc2ef564b7eea Mon Sep 17 00:00:00 2001 From: Jefsky Date: Sun, 31 May 2026 00:09:53 +0800 Subject: [PATCH] Fix incorrect batch temporal IDs for cond_model_input in dreambooth flux2 img2img The _prepare_image_ids method assigns different temporal embeddings (T=10, T=20, T=30...) to distinguish multiple reference images within a single sample. However, in the training script, cond_model_input has shape (B, C, H, W) where each batch element is an independent training sample with only one conditional image. The previous implementation split the batch into individual samples, producing incorrect cross-sample temporal offsets (sample 0 -> T=10, sample 1 -> T=20, etc.). Fix: generate temporal IDs for one sample and expand across the batch dimension, so all samples use the same temporal ID (T=10). --- .../train_dreambooth_lora_flux2_img2img.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 477697fadb64..a5c968670417 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1717,12 +1717,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( - device=cond_model_input.device - ) - cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], -1, model_input_ids.shape[-1] + # Each batch element is an independent training sample with a single + # conditional image. Generate temporal IDs for one sample and expand + # across the batch, avoiding incorrect cross-sample temporal offsets. + cond_model_input_ids = Flux2Pipeline._prepare_image_ids( + [cond_model_input[0:1]] + ).to(device=cond_model_input.device) + cond_model_input_ids = cond_model_input_ids.expand( + cond_model_input.shape[0], -1, -1 ) # Sample noise that we'll add to the latents