@@ -14,7 +14,7 @@

 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
@@ -75,6 +75,7 @@

         >>> # Generate video
         >>> generator = torch.Generator("cuda").manual_seed(0)
+        >>> # Text-only conditioning is also supported without the need to pass `conditions`
         >>> video = pipe(
         ...     conditions=[condition1, condition2],
         ...     prompt=prompt,
@@ -223,7 +224,7 @@ def retrieve_latents(

 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
-    Pipeline for image-to-video generation.
+    Pipeline for text/image/video-to-video generation.

     Reference: https://github.com/Lightricks/LTX-Video

@@ -482,9 +483,6 @@ def check_inputs(
         if conditions is not None and (image is not None or video is not None):
             raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")

-        if conditions is None and (image is None and video is None):
-            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
         if conditions is None:
             if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
                 raise ValueError(
@@ -642,9 +640,9 @@ def add_noise_to_image_conditioning_latents(

     def prepare_latents(
         self,
-        conditions: List[torch.Tensor],
-        condition_strength: List[float],
-        condition_frame_index: List[int],
+        conditions: Optional[List[torch.Tensor]] = None,
+        condition_strength: Optional[List[float]] = None,
+        condition_frame_index: Optional[List[int]] = None,
         batch_size: int = 1,
         num_channels_latents: int = 128,
         height: int = 512,
@@ -654,85 +652,88 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-    ) -> None:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio

         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
-        extra_conditioning_latents = []
-        extra_conditioning_video_ids = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-        for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
-            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            condition_latents = self._normalize_latents(
-                condition_latents, self.vae.latents_mean, self.vae.latents_std
-            ).to(device, dtype=dtype)
-
-            num_data_frames = data.size(2)
-            num_cond_frames = condition_latents.size(2)
-
-            if frame_index == 0:
-                latents[:, :, :num_cond_frames] = torch.lerp(
-                    latents[:, :, :num_cond_frames], condition_latents, strength
-                )
-                condition_latent_frames_mask[:, :num_cond_frames] = strength
+        if len(conditions) > 0:
+            condition_latent_frames_mask = torch.zeros(
+                (batch_size, num_latent_frames), device=device, dtype=torch.float32
+            )

-            else:
-                if num_data_frames > 1:
-                    if num_cond_frames < num_prefix_latent_frames:
-                        raise ValueError(
-                            f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
-                        )
-
-                    if num_cond_frames > num_prefix_latent_frames:
-                        start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
-                        end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
-                        latents[:, :, start_frame:end_frame] = torch.lerp(
-                            latents[:, :, start_frame:end_frame],
-                            condition_latents[:, :, num_prefix_latent_frames:],
-                            strength,
-                        )
-                        condition_latent_frames_mask[:, start_frame:end_frame] = strength
-                        condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
-                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-                condition_latents = torch.lerp(noise, condition_latents, strength)
-
-                condition_video_ids = self._prepare_video_ids(
-                    batch_size,
-                    condition_latents.size(2),
-                    latent_height,
-                    latent_width,
-                    patch_size=self.transformer_spatial_patch_size,
-                    patch_size_t=self.transformer_temporal_patch_size,
-                    device=device,
-                )
-                condition_video_ids = self._scale_video_ids(
-                    condition_video_ids,
-                    scale_factor=self.vae_spatial_compression_ratio,
-                    scale_factor_t=self.vae_temporal_compression_ratio,
-                    frame_index=frame_index,
-                    device=device,
-                )
-                condition_latents = self._pack_latents(
-                    condition_latents,
-                    self.transformer_spatial_patch_size,
-                    self.transformer_temporal_patch_size,
-                )
-                condition_conditioning_mask = torch.full(
-                    condition_latents.shape[:2], strength, device=device, dtype=dtype
-                )
+            extra_conditioning_latents = []
+            extra_conditioning_video_ids = []
+            extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+            for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+                condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+                condition_latents = self._normalize_latents(
+                    condition_latents, self.vae.latents_mean, self.vae.latents_std
+                ).to(device, dtype=dtype)
+
+                num_data_frames = data.size(2)
+                num_cond_frames = condition_latents.size(2)
+
+                if frame_index == 0:
+                    latents[:, :, :num_cond_frames] = torch.lerp(
+                        latents[:, :, :num_cond_frames], condition_latents, strength
+                    )
+                    condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+                else:
+                    if num_data_frames > 1:
+                        if num_cond_frames < num_prefix_latent_frames:
+                            raise ValueError(
+                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+                            )
+
+                        if num_cond_frames > num_prefix_latent_frames:
+                            start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+                            end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+                            latents[:, :, start_frame:end_frame] = torch.lerp(
+                                latents[:, :, start_frame:end_frame],
+                                condition_latents[:, :, num_prefix_latent_frames:],
+                                strength,
+                            )
+                            condition_latent_frames_mask[:, start_frame:end_frame] = strength
+                            condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+                    noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+                    condition_latents = torch.lerp(noise, condition_latents, strength)
+
+                    condition_video_ids = self._prepare_video_ids(
+                        batch_size,
+                        condition_latents.size(2),
+                        latent_height,
+                        latent_width,
+                        patch_size=self.transformer_spatial_patch_size,
+                        patch_size_t=self.transformer_temporal_patch_size,
+                        device=device,
+                    )
+                    condition_video_ids = self._scale_video_ids(
+                        condition_video_ids,
+                        scale_factor=self.vae_spatial_compression_ratio,
+                        scale_factor_t=self.vae_temporal_compression_ratio,
+                        frame_index=frame_index,
+                        device=device,
+                    )
+                    condition_latents = self._pack_latents(
+                        condition_latents,
+                        self.transformer_spatial_patch_size,
+                        self.transformer_temporal_patch_size,
+                    )
+                    condition_conditioning_mask = torch.full(
+                        condition_latents.shape[:2], strength, device=device, dtype=dtype
+                    )

-                extra_conditioning_latents.append(condition_latents)
-                extra_conditioning_video_ids.append(condition_video_ids)
-                extra_conditioning_mask.append(condition_conditioning_mask)
-                extra_conditioning_num_latents += condition_latents.size(1)
+                    extra_conditioning_latents.append(condition_latents)
+                    extra_conditioning_video_ids.append(condition_video_ids)
+                    extra_conditioning_mask.append(condition_conditioning_mask)
+                    extra_conditioning_num_latents += condition_latents.size(1)

         video_ids = self._prepare_video_ids(
             batch_size,
@@ -743,7 +744,10 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        if len(conditions) > 0:
+            conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        else:
+            conditioning_mask, extra_conditioning_num_latents = None, 0
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ def prepare_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )

-        if len(extra_conditioning_latents) > 0:
+        if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
@@ -955,7 +959,7 @@ def __call__(
             frame_index = [condition.frame_index for condition in conditions]
             image = [condition.image for condition in conditions]
             video = [condition.video for condition in conditions]
-        else:
+        elif image is not None or video is not None:
             if not isinstance(image, list):
                 image = [image]
                 num_conditions = 1
@@ -999,32 +1003,34 @@ def __call__(
         vae_dtype = self.vae.dtype

         conditioning_tensors = []
-        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
-            image, video, frame_index, strength
-        ):
-            if condition_image is not None:
-                condition_tensor = (
-                    self.video_processor.preprocess(condition_image, height, width)
-                    .unsqueeze(2)
-                    .to(device, dtype=vae_dtype)
-                )
-            elif condition_video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
-                num_frames_input = condition_tensor.size(2)
-                num_frames_output = self.trim_conditioning_sequence(
-                    condition_frame_index, num_frames_input, num_frames
-                )
-                condition_tensor = condition_tensor[:, :, :num_frames_output]
-                condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
-            else:
-                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
-
-            if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
-                raise ValueError(
-                    f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
-                    f"but got {condition_tensor.size(2)} frames."
-                )
-            conditioning_tensors.append(condition_tensor)
+        is_conditioning_image_or_video = image is not None or video is not None
+        if is_conditioning_image_or_video:
+            for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+                image, video, frame_index, strength
+            ):
+                if condition_image is not None:
+                    condition_tensor = (
+                        self.video_processor.preprocess(condition_image, height, width)
+                        .unsqueeze(2)
+                        .to(device, dtype=vae_dtype)
+                    )
+                elif condition_video is not None:
+                    condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+                    num_frames_input = condition_tensor.size(2)
+                    num_frames_output = self.trim_conditioning_sequence(
+                        condition_frame_index, num_frames_input, num_frames
+                    )
+                    condition_tensor = condition_tensor[:, :, :num_frames_output]
+                    condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+                else:
+                    raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+                if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+                    raise ValueError(
+                        f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+                        f"but got {condition_tensor.size(2)} frames."
+                    )
+                conditioning_tensors.append(condition_tensor)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
@@ -1045,7 +1051,7 @@ def __call__(
         video_coords = video_coords.float()
         video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)

-        init_latents = latents.clone()
+        init_latents = latents.clone() if is_conditioning_image_or_video else None

         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,15 +1071,15 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # 7. Denoising loop
+        # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t

-                if image_cond_noise_scale > 0:
+                if image_cond_noise_scale > 0 and init_latents is not None:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
                     latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,16 +1092,18 @@ def __call__(
                     )

                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                conditioning_mask_model_input = (
-                    torch.cat([conditioning_mask, conditioning_mask])
-                    if self.do_classifier_free_guidance
-                    else conditioning_mask
-                )
+                if is_conditioning_image_or_video:
+                    conditioning_mask_model_input = (
+                        torch.cat([conditioning_mask, conditioning_mask])
+                        if self.do_classifier_free_guidance
+                        else conditioning_mask
+                    )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                if is_conditioning_image_or_video:
+                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1115,8 +1123,11 @@ def __call__(
                 denoised_latents = self.scheduler.step(
                     -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
-                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                if is_conditioning_image_or_video:
+                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                else:
+                    latents = denoised_latents

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -1134,7 +1145,9 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()

-        latents = latents[:, extra_conditioning_num_latents:]
+        if is_conditioning_image_or_video:
+            latents = latents[:, extra_conditioning_num_latents:]
+
         latents = self._unpack_latents(
             latents,
             latent_num_frames,
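
With this change, `conditions`, `image`, and `video` are all optional, so the pipeline can run as a plain text-to-video model: the conditioning branch is skipped, `conditioning_mask` stays `None`, and the full latent sequence is denoised. A minimal text-only sketch under assumed settings (the checkpoint id, resolution, frame count, and step count below are illustrative choices, not part of this diff; any LTX-Video checkpoint compatible with `LTXConditionPipeline` should behave the same way):

import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id for illustration; substitute the one you actually use.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# No `conditions`, `image`, or `video` passed: with this PR the call no longer
# raises in `check_inputs`, and generation is driven by the prompt alone.
video = pipe(
    prompt="A timelapse of storm clouds rolling over a mountain ridge at sunset",
    width=768,
    height=512,
    num_frames=97,  # LTX frame counts are conventionally of the form k * 8 + 1
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=24)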