@@ -338,7 +338,7 @@ def encode_prompt(
 
     def prepare_latents(
         self,
-        image: Optional[torch.Tensor] = None,
+        image: torch.Tensor,
         batch_size: int = 1,
         num_channels_latents: int = 16,
         num_frames: int = 13,
@@ -363,47 +363,46 @@ def prepare_latents(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
+
+        assert image.ndim == 4
+        image = image.unsqueeze(2)  # [B, C, F, H, W]
+        print(image.shape)
 
-        if latents is None:
-            assert image.ndim == 4
-            image = image.unsqueeze(2)  # [B, C, F, H, W]
-            print(image.shape)
-
-            if isinstance(generator, list):
-                if len(generator) != batch_size:
-                    raise ValueError(
-                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                    )
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
 
-                init_latents = [
-                    retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
-                ]
-            else:
-                init_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
-
-            init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-            init_latents = self.vae.config.scaling_factor * init_latents
-
-            padding_shape = (
-                batch_size,
-                num_frames - 1,
-                num_channels_latents,
-                height // self.vae_scale_factor_spatial,
-                width // self.vae_scale_factor_spatial,
-            )
-            latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
-            print(init_latents.shape, latent_padding.shape)
-            init_latents = torch.cat([init_latents, latent_padding], dim=1)
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+            ]
+        else:
+            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = torch.cat([noise, init_latents], dim=2)
+        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        padding_shape = (
+            batch_size,
+            num_frames - 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+        print(image_latents.shape, latent_padding.shape)
+        image_latents = torch.cat([image_latents, latent_padding], dim=1)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
-        return latents
+        return latents, image_latents
 
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
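
Note (not part of the diff): the reworked `prepare_latents` now always encodes the conditioning image with the VAE, zero-pads it along the frame axis, and returns those `image_latents` next to the noise `latents` instead of folding them together. A minimal standalone sketch of that padding step, using made-up shapes and random tensors in place of the VAE encode, could look like this:

import torch

# Illustrative shapes only; the real values come from the VAE config and the call arguments.
batch_size, num_frames, num_channels_latents = 2, 13, 16
latent_height, latent_width = 60, 90  # e.g. 480x720 with a spatial scale factor of 8

# Stand-in for retrieve_latents(self.vae.encode(...)): one latent frame per input image.
image_latents = torch.randn(batch_size, 1, num_channels_latents, latent_height, latent_width)

# Zero padding for the remaining frames, mirroring padding_shape in the hunk above.
latent_padding = torch.zeros(batch_size, num_frames - 1, num_channels_latents, latent_height, latent_width)
image_latents = torch.cat([image_latents, latent_padding], dim=1)  # [B, F, C, H, W]

print(image_latents.shape)  # torch.Size([2, 13, 16, 60, 90])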
@@ -446,7 +445,6 @@ def check_inputs(
         prompt,
         height,
         width,
-        strength,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
         video=None,
@@ -457,9 +455,6 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
@@ -567,12 +562,10 @@ def __call__(
         num_frames: int = 49,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
-        strength: float = 0.8,
         guidance_scale: float = 6,
         use_dynamic_cfg: bool = False,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
-        noise_aug_strength: float = 0.02,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -614,8 +607,6 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
-            strength (`float`, *optional*, defaults to 0.8):
-                Higher strength leads to more differences between original video and generated video.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -682,7 +673,6 @@ def __call__(
             prompt,
             height,
             width,
-            strength,
             negative_prompt,
             callback_on_step_end_tensor_inputs,
             prompt_embeds,
@@ -730,7 +720,7 @@ def __call__(
         )
 
         latent_channels = self.transformer.config.in_channels // 2
-        latents = self.prepare_latents(
+        latents, image_latents = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             latent_channels,
@@ -765,6 +755,10 @@ def __call__(
 
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+                print(latent_model_input.shape)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
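
Note (not part of the diff): the denoising loop now conditions on the image latents by duplicating them for classifier-free guidance and concatenating them with the noise latents along the channel dimension, which is also why `latent_channels` is derived as `self.transformer.config.in_channels // 2` earlier in `__call__`. A rough sketch of just that concatenation, again with assumed shapes:

import torch

# Assumed shapes for illustration; see the prepare_latents hunk for how these are built.
batch_size, num_frames, latent_channels = 2, 13, 16
latent_height, latent_width = 60, 90
do_classifier_free_guidance = True

latents = torch.randn(batch_size, num_frames, latent_channels, latent_height, latent_width)
image_latents = torch.randn(batch_size, num_frames, latent_channels, latent_height, latent_width)

# Duplicate along the batch dimension for the unconditional/conditional passes.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents

# Channel-wise concat (dim=2 of [B, F, C, H, W]) doubles C, matching in_channels // 2 above.
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
print(latent_model_input.shape)  # torch.Size([4, 13, 32, 60, 90])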