Commit 9cd3755

Authored by yiyixuxu, christopher-beckham, and sayakpaul
flux controlnet fix (control_modes batch & others) (huggingface#9507)
* flux controlnet mode to take into account batch size
* incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case
* fix
* fix use_guidance when controlnet is a multi and does not have config

Co-authored-by: Christopher Beckham <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 1c6ede9 commit 9cd3755
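
For orientation, here is a minimal usage sketch of the pipeline path this commit fixes; the model ids, image URL, and generation settings are illustrative assumptions, not part of the commit.

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Model ids are assumptions for illustration; any Flux base + union ControlNet pair applies.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("https://example.com/canny.png")  # placeholder URL

image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    control_mode=0,            # single ControlNet: an int (or None); expanded per batch sample after this fix
    num_images_per_prompt=2,   # the fix expands control_mode to match this batch dimension
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]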

File tree: 2 files changed, +37 −25 lines


src/diffusers/models/controlnet_flux.py

Lines changed: 12 additions & 11 deletions
@@ -502,16 +502,17 @@ def forward(
                     control_block_samples = block_samples
                     control_single_block_samples = single_block_samples
                 else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
 
         return control_block_samples, control_single_block_samples
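
Below is a small, self-contained sketch of the guarded merging behaviour introduced here; the merge helper and the tensors are made up for illustration. When either the running totals or the new samples are None, the totals are left untouched instead of being zipped against None.

import torch

def merge(totals, new_samples):
    # Mirrors the guard added above: skip accumulation when either side is None.
    if totals is None or new_samples is None:
        return totals
    return [t + s for t, s in zip(totals, new_samples)]

block_totals = [torch.ones(2, 4)]                    # running sums from an earlier ControlNet
single_totals = None                                 # e.g. no single-block samples were produced

block_totals = merge(block_totals, [torch.full((2, 4), 2.0)])
single_totals = merge(single_totals, None)           # an unguarded zip over None would raise TypeError

print(block_totals[0][0])                            # tensor([3., 3., 3., 3.])
print(single_totals)                                 # None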

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 25 additions & 14 deletions
@@ -747,10 +747,12 @@ def __call__(
                 width_control_image,
             )
 
-            # set control mode
+            # Here we ensure that `control_mode` has the same length as the control_image.
             if control_mode is not None:
+                if not isinstance(control_mode, int):
+                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
-                control_mode = control_mode.reshape([-1, 1])
+                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
 
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
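
To see what the new single-ControlNet branch produces, here is an isolated sketch (batch size and mode index are made-up values): a scalar control_mode becomes a (batch, 1) long tensor, so each packed control-image sample carries its own mode id.

import torch

batch_size = 4          # illustrative: batch_size * num_images_per_prompt after packing
control_mode = 2        # illustrative union-ControlNet mode index

mode = torch.tensor(control_mode).to("cpu", dtype=torch.long)
mode = mode.view(-1, 1).expand(batch_size, 1)

print(mode.shape)       # torch.Size([4, 1])
print(mode[:, 0])       # tensor([2, 2, 2, 2])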
@@ -785,16 +787,22 @@ def __call__(
 
             control_image = control_images
 
+            # Here we ensure that `control_mode` has the same length as the control_image.
+            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    + " length as the number of controlnets (control images) specified"
+                )
+            if not isinstance(control_mode, list):
+                control_mode = [control_mode] * len(control_image)
             # set control mode
-            control_mode_ = []
-            if isinstance(control_mode, list):
-                for cmode in control_mode:
-                    if cmode is None:
-                        control_mode_.append(-1)
-                    else:
-                        control_mode_.append(cmode)
-            control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
-            control_mode = control_mode.reshape([-1, 1])
+            control_modes = []
+            for cmode in control_mode:
+                if cmode is None:
+                    cmode = -1
+                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+                control_modes.append(control_mode)
+            control_mode = control_modes
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
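
The multi-ControlNet branch now normalizes control_mode into one batch-sized tensor per ControlNet, with -1 standing in for "no mode". A standalone sketch with made-up sizes:

import torch

batch = 2                       # illustrative leading dimension of the packed control images
control_mode = [0, None]        # one entry per ControlNet; None means no mode was requested

if not isinstance(control_mode, list):
    control_mode = [control_mode] * 2

control_modes = []
for cmode in control_mode:
    if cmode is None:
        cmode = -1              # sentinel for "no mode", matching the diff above
    control_modes.append(torch.tensor(cmode).expand(batch).to(dtype=torch.long))

print([m.tolist() for m in control_modes])   # [[0, 0], [-1, -1]]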
@@ -840,9 +848,12 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 # controlnet