@@ -194,6 +194,19 @@ def __init__(
194194 super ().__init__ ()
195195 if isinstance (controlnet , (list , tuple )):
196196 controlnet = SD3MultiControlNetModel (controlnet )
197+ if isinstance (controlnet , SD3MultiControlNetModel ):
198+ for controlnet_model in controlnet .nets :
199+ # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
200+ if (
201+ hasattr (controlnet_model .config , "use_pos_embed" )
202+ and controlnet_model .config .use_pos_embed is False
203+ ):
204+ pos_embed = controlnet_model ._get_pos_embed_from_transformer (transformer )
205+ controlnet_model .pos_embed = pos_embed .to (controlnet_model .dtype ).to (controlnet_model .device )
206+ elif isinstance (controlnet , SD3ControlNetModel ):
207+ if hasattr (controlnet .config , "use_pos_embed" ) and controlnet .config .use_pos_embed is False :
208+ pos_embed = controlnet ._get_pos_embed_from_transformer (transformer )
209+ controlnet .pos_embed = pos_embed .to (controlnet .dtype ).to (controlnet .device )
197210
198211 self .register_modules (
199212 vae = vae ,
@@ -1042,15 +1055,9 @@ def __call__(
10421055 controlnet_cond_scale = controlnet_cond_scale [0 ]
10431056 cond_scale = controlnet_cond_scale * controlnet_keep [i ]
10441057
1045- if controlnet_config .use_pos_embed is False :
1046- # sd35 (offical) 8b controlnet
1047- controlnet_model_input = self .transformer .pos_embed (latent_model_input )
1048- else :
1049- controlnet_model_input = latent_model_input
1050-
10511058 # controlnet(s) inference
10521059 control_block_samples = self .controlnet (
1053- hidden_states = controlnet_model_input ,
1060+ hidden_states = latent_model_input ,
10541061 timestep = timestep ,
10551062 encoder_hidden_states = controlnet_encoder_hidden_states ,
10561063 pooled_projections = controlnet_pooled_projections ,
0 commit comments