 # limitations under the License.

 import inspect
-
-import PIL
 import math
 from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
 import torch
-from PIL import Image
 from transformers import T5EncoderModel, T5Tokenizer

-from ...image_processor import PipelineImageInput
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -36,6 +35,7 @@
 from ...video_processor import VideoProcessor
 from .pipeline_output import CogVideoXPipelineOutput

+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -157,14 +157,12 @@ def _gaussian(window_size: int, sigma):
         >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
-
+
         >>> image = load_image(
         ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
         ... )
         >>> image = image.resize((720, 480))
-        >>> video = pipe(
-        ...     image=image, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
-        ... ).frames[0]
+        >>> video = pipe(image=image, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50).frames[0]
         >>> export_to_video(video, "output.mp4", fps=8)
         ```
 """
@@ -191,12 +189,12 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-        scheduler,
-        num_inference_steps: Optional[int] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        timesteps: Optional[List[int]] = None,
-        sigmas: Optional[List[float]] = None,
-        **kwargs,
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
 ):
     """
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
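
Note on the reflowed helper, since the signature alone doesn't show it: `retrieve_timesteps` forwards exactly one of `num_inference_steps`, `timesteps`, or `sigmas` to `scheduler.set_timesteps`, and the last two are mutually exclusive. A minimal usage sketch (the checkpoint path is illustrative):

```python
# Sketch: default spacing. Custom `timesteps`/`sigmas` are accepted only if the
# scheduler's `set_timesteps` supports them; passing both raises a ValueError.
from diffusers import CogVideoXDPMScheduler

scheduler = CogVideoXDPMScheduler.from_pretrained("THUDM/CogVideoX-5b-I2V", subfolder="scheduler")
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda")
```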
@@ -251,7 +249,7 @@ def retrieve_timesteps(

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-        encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
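
For context (the rest of `retrieve_latents` sits outside this hunk): `"sample"` draws from the VAE posterior, while the `"argmax"` mode of the copied helper returns the posterior mode for deterministic encoding. A rough sketch, assuming a loaded `vae` and a `pixel_values` batch (both hypothetical here):

```python
posterior = vae.encode(pixel_values).latent_dist  # DiagonalGaussianDistribution
stochastic = posterior.sample(generator=None)     # "sample": varies per draw
deterministic = posterior.mode()                  # "argmax": reproducible
```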
@@ -296,13 +294,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
     ]

     def __init__(
-            self,
-            tokenizer: T5Tokenizer,
-            text_encoder: T5EncoderModel,
-            image_encoder: AutoencoderKLCogVideoX,
-            vae: AutoencoderKLCogVideoX,
-            transformer: CogVideoXTransformer3DModel,
-            scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        image_encoder: AutoencoderKLCogVideoX,
+        vae: AutoencoderKLCogVideoX,
+        transformer: CogVideoXTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
     ):
         super().__init__()

@@ -312,7 +310,7 @@ def __init__(
             image_encoder=image_encoder,
             vae=vae,
             transformer=transformer,
-            scheduler=scheduler
+            scheduler=scheduler,
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
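
A quick sanity check on the context lines above: the published CogVideoX VAE config has four `block_out_channels` entries (an assumption worth verifying against the checkpoint), so:

```python
vae_scale_factor_spatial = 2 ** (4 - 1)           # = 8
latent_height, latent_width = 480 // 8, 720 // 8  # = 60, 90 for the default size
```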
@@ -324,11 +322,11 @@ def __init__(
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     def _encode_image(
-            self,
-            image: PipelineImageInput,
-            device: Union[str, torch.device],
-            num_videos_per_prompt: int,
-            do_classifier_free_guidance: bool,
+        self,
+        image: PipelineImageInput,
+        device: Union[str, torch.device],
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
     ) -> torch.Tensor:
         dtype = next(self.image_encoder.parameters()).dtype

@@ -342,7 +340,6 @@ def _encode_image(
         image = _resize_with_antialiasing(image, (224, 224))
         image = (image + 1.0) / 2.0

-
         # encode image using VAE
         image = image.to(device=device, dtype=dtype)
         image_embeddings = self.image_encoder(image).image_embeds
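
The normalization above is the usual range shift: pipeline image tensors arrive in [-1, 1], and `(image + 1.0) / 2.0` maps them to [0, 1] after the 224x224 anti-aliased resize. A one-line check:

```python
# (x + 1) / 2 maps [-1, 1] -> [0, 1]
assert [(x + 1.0) / 2.0 for x in (-1.0, 0.0, 1.0)] == [0.0, 0.5, 1.0]
```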
@@ -365,12 +362,12 @@ def _encode_image(

     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
-            self,
-            prompt: Union[str, List[str]] = None,
-            num_videos_per_prompt: int = 1,
-            max_sequence_length: int = 226,
-            device: Optional[torch.device] = None,
-            dtype: Optional[torch.dtype] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -390,7 +387,7 @@ def _get_t5_prompt_embeds(
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1: -1])
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
@@ -408,16 +405,16 @@ def _get_t5_prompt_embeds(

     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
     def encode_prompt(
-            self,
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            do_classifier_free_guidance: bool = True,
-            num_videos_per_prompt: int = 1,
-            prompt_embeds: Optional[torch.Tensor] = None,
-            negative_prompt_embeds: Optional[torch.Tensor] = None,
-            max_sequence_length: int = 226,
-            device: Optional[torch.device] = None,
-            dtype: Optional[torch.dtype] = None,
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
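
A hedged usage sketch for `encode_prompt` as reflowed above, assuming a loaded pipeline `pipe` (the prompt text is illustrative):

```python
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a panda playing a guitar by a river",
    negative_prompt="",
    do_classifier_free_guidance=True,  # also produces the negative embeddings
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="cuda",
)
```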
@@ -489,15 +486,7 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds

     def prepare_latents(
-        self,
-        batch_size,
-        num_channels_latents,
-        num_frames,
-        height, width,
-        dtype,
-        device,
-        generator,
-        latents=None
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
         shape = (
             batch_size,
@@ -535,7 +524,7 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = timesteps[t_start * self.scheduler.order:]
+        timesteps = timesteps[t_start * self.scheduler.order :]

         return timesteps, num_inference_steps - t_start

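
Worked numbers for the strength logic above, using the pipeline defaults:

```python
num_inference_steps, strength = 50, 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# with scheduler.order == 1 for the CogVideoX schedulers, timesteps[10:] leaves
# 40 denoising steps; strength=1.0 keeps the full 50-step schedule
```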
@@ -558,17 +547,17 @@ def prepare_extra_step_kwargs(self, generator, eta):
         return extra_step_kwargs

     def check_inputs(
-            self,
-            prompt,
-            height,
-            width,
-            strength,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            video=None,
-            latents=None,
-            prompt_embeds=None,
-            negative_prompt_embeds=None,
+        self,
+        prompt,
+        height,
+        width,
+        strength,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        video=None,
+        latents=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -577,7 +566,7 @@ def check_inputs(
             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

         if callback_on_step_end_tensor_inputs is not None and not all(
-                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -634,11 +623,11 @@ def unfuse_qkv_projections(self) -> None:

     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
     def _prepare_rotary_positional_embeddings(
-            self,
-            height: int,
-            width: int,
-            num_frames: int,
-            device: torch.device,
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
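
Plugging the defaults into the context lines above, and assuming the CogVideoX transformer's `patch_size` is 2:

```python
grid_height = 480 // (8 * 2)  # = 30 latent patches vertically
grid_width = 720 // (8 * 2)   # = 45 latent patches horizontally
```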
@@ -674,32 +663,32 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-            self,
-            image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
-            prompt: Optional[Union[str, List[str]]] = None,
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            height: int = 480,
-            width: int = 720,
-            num_frames: int = 49,
-            num_inference_steps: int = 50,
-            timesteps: Optional[List[int]] = None,
-            strength: float = 0.8,
-            guidance_scale: float = 6,
-            use_dynamic_cfg: bool = False,
-            num_videos_per_prompt: int = 1,
-            eta: float = 0.0,
-            noise_aug_strength: float = 0.02,
-            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            prompt_embeds: Optional[torch.FloatTensor] = None,
-            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-            output_type: str = "pil",
-            return_dict: bool = True,
-            callback_on_step_end: Optional[
-                Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
-            ] = None,
-            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-            max_sequence_length: int = 226,
+        self,
+        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 480,
+        width: int = 720,
+        num_frames: int = 49,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        strength: float = 0.8,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        noise_aug_strength: float = 0.02,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
     ) -> Union[CogVideoXPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -827,7 +816,7 @@ def __call__(
             image=image,
             device=device,
             num_videos_per_prompt=num_videos_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance
+            do_classifier_free_guidance=do_classifier_free_guidance,
         )
         image = self.video_processor.preprocess(image, height=height, width=width).to(device)
         noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
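
The `noise` tensor created above presumably feeds SVD-style conditioning augmentation; the exact line sits outside this hunk, but the pattern (a hypothetical reconstruction, scaled by the `noise_aug_strength=0.02` default) would be:

```python
# Assumption: the conditioning image is perturbed before VAE encoding.
image = image + noise_aug_strength * noise
```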
@@ -904,8 +893,7 @@ def __call__(
                 # perform guidance
                 if use_dynamic_cfg:
                     self._guidance_scale = 1 + guidance_scale * (
-                        (1 - math.cos(
-                            math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                     )
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
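
The reflowed dynamic-CFG expression reads more easily as a function. A sketch of the schedule it implements, assuming `(num_inference_steps - t) / num_inference_steps` sweeps from 0 to 1 over sampling:

```python
import math

def dynamic_cfg(t: float, guidance_scale: float = 6.0, num_inference_steps: int = 50) -> float:
    # Quintic cosine ramp: stays near 1 for most of sampling, rises sharply at the end.
    progress = (num_inference_steps - t) / num_inference_steps
    return 1 + guidance_scale * ((1 - math.cos(math.pi * progress**5.0)) / 2)

# dynamic_cfg(50) == 1.0, dynamic_cfg(10) ~= 2.45, dynamic_cfg(0) == 7.0
```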