@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]
 
     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
@@ -336,7 +338,15 @@ def check_inputs(
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
 
@@ -414,7 +424,11 @@ def preprocess_conditions(
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])
 
             if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ def prepare_masks(
589603 "Generating with more than one video is not yet supported. This may be supported in the future."
590604 )
591605
592- transformer_patch_size = self .transformer .config .patch_size [1 ]
606+ transformer_patch_size = (
607+ self .transformer .config .patch_size [1 ]
608+ if self .transformer is not None
609+ else self .transformer_2 .config .patch_size [1 ]
610+ )
593611
594612 mask_list = []
595613 for mask_ , reference_images_batch in zip (mask , reference_images ):
@@ -844,20 +862,25 @@ def __call__(
         batch_size = prompt_embeds.shape[0]
 
         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
 
+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )
         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)
         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)
         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
 
@@ -900,7 +923,11 @@ def __call__(
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
 
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ def __call__(
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
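Usage note (not part of the diff): with `transformer` added to `_optional_components`, the pipeline can be constructed with only the low-noise `transformer_2`. Below is a minimal sketch of what that enables; the checkpoint id is a placeholder and the call arguments are illustrative, assuming `from_pretrained` accepts `transformer=None` for the now-optional component the same way it does for other optional components.

```python
import torch
from diffusers import WanVACEPipeline

# Placeholder repo id; substitute a real Wan VACE checkpoint.
pipe = WanVACEPipeline.from_pretrained(
    "your-org/wan-vace-checkpoint",
    transformer=None,            # skip the high-noise transformer; transformer_2 handles all timesteps
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# With only one transformer loaded, denoising is single-stage. When both transformers and a
# boundary_ratio are set, timesteps >= boundary_timestep use `transformer` and the rest use `transformer_2`.
output = pipe(
    prompt="A cat walking through a garden",
    num_inference_steps=30,
).frames[0]
```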