@@ -747,10 +747,12 @@ def __call__(
                     width_control_image,
                 )
 
-            # set control mode
+            # Here we ensure that `control_mode` has the same length as the control_image.
             if control_mode is not None:
+                if not isinstance(control_mode, int):
+                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
-                control_mode = control_mode.reshape([-1, 1])
+                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
 
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
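For reference, the single-ControlNet branch above turns the integer `control_mode` into one mode id per sample of the packed control latents. A minimal standalone sketch of that broadcasting (the batch size and packed latent shape below are illustrative assumptions, not values from the pipeline):

import torch

batch_size = 2
control_mode = 1                                   # assumed mode id for a union-style Flux ControlNet
control_image = torch.randn(batch_size, 4096, 64)  # stand-in for the packed control latents

# mirror the added lines: scalar -> long tensor -> one (mode,) row per sample
control_mode = torch.tensor(control_mode).to(dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
print(control_mode.shape)  # torch.Size([2, 1])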
@@ -785,16 +787,22 @@ def __call__(
 
             control_image = control_images
 
+            # Here we ensure that `control_mode` has the same length as the control_image.
+            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    + " length as the number of controlnets (control images) specified"
+                )
+            if not isinstance(control_mode, list):
+                control_mode = [control_mode] * len(control_image)
             # set control mode
-            control_mode_ = []
-            if isinstance(control_mode, list):
-                for cmode in control_mode:
-                    if cmode is None:
-                        control_mode_.append(-1)
-                    else:
-                        control_mode_.append(cmode)
-            control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
-            control_mode = control_mode.reshape([-1, 1])
+            control_modes = []
+            for cmode in control_mode:
+                if cmode is None:
+                    cmode = -1
+                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+                control_modes.append(control_mode)
+            control_mode = control_modes
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
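In the multi-ControlNet branch, each entry of the `control_mode` list (with `None` mapped to -1) now becomes its own tensor expanded to the batch size of the packed control images, so the result is a list with one mode tensor per controlnet. A small standalone sketch of that loop (batch size and latent shapes are illustrative assumptions):

import torch

batch_size = 2
control_mode = [0, None]  # one entry per controlnet / control image
control_images = [torch.randn(batch_size, 4096, 64) for _ in control_mode]

control_modes = []
for cmode in control_mode:
    if cmode is None:
        cmode = -1  # sentinel meaning "no mode" for that controlnet
    mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(dtype=torch.long)
    control_modes.append(mode)

print([m.tolist() for m in control_modes])  # [[0, 0], [-1, -1]]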
@@ -840,9 +848,12 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 # controlnet
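The guidance change above reads `guidance_embeds` from the first wrapped net when the controlnet is a `FluxMultiControlNetModel`, since that flag lives on the per-net config. A self-contained sketch of the same selection logic with stand-in classes (the `_Net` and `_MultiNet` names are placeholders, not diffusers types):

import torch

class _Net:                      # stands in for a single FluxControlNetModel
    def __init__(self, guidance_embeds):
        self.config = type("Cfg", (), {"guidance_embeds": guidance_embeds})()

class _MultiNet:                 # stands in for FluxMultiControlNetModel
    def __init__(self, nets):
        self.nets = nets

guidance_scale = 3.5
latents = torch.randn(2, 4096, 64)
controlnet = _MultiNet([_Net(True), _Net(False)])

if isinstance(controlnet, _MultiNet):
    use_guidance = controlnet.nets[0].config.guidance_embeds
else:
    use_guidance = controlnet.config.guidance_embeds

guidance = torch.tensor([guidance_scale]) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
print(guidance)  # tensor([3.5000, 3.5000])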