@@ -22,7 +22,6 @@
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -39,113 +38,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
-    h, w = input.shape[-2:]
-    factors = (h / size[0], w / size[1])
-
-    # First, we have to determine sigma
-    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
-    sigmas = (
-        max((factors[0] - 1.0) / 2.0, 0.001),
-        max((factors[1] - 1.0) / 2.0, 0.001),
-    )
-
-    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
-    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
-    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
-    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
-
-    # Make sure it is odd
-    if (ks[0] % 2) == 0:
-        ks = ks[0] + 1, ks[1]
-
-    if (ks[1] % 2) == 0:
-        ks = ks[0], ks[1] + 1
-
-    input = _gaussian_blur2d(input, ks, sigmas)
-
-    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
-    return output
-
-
-def _gaussian_blur2d(input, kernel_size, sigma):
-    if isinstance(sigma, tuple):
-        sigma = torch.tensor([sigma], dtype=input.dtype)
-    else:
-        sigma = sigma.to(dtype=input.dtype)
-
-    ky, kx = int(kernel_size[0]), int(kernel_size[1])
-    bs = sigma.shape[0]
-    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
-    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
-    out_x = _filter2d(input, kernel_x[..., None, :])
-    out = _filter2d(out_x, kernel_y[..., None])
-
-    return out
-
-
-def _compute_padding(kernel_size):
-    """Compute padding tuple."""
-    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
-    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
-    if len(kernel_size) < 2:
-        raise AssertionError(kernel_size)
-    computed = [k - 1 for k in kernel_size]
-
-    # for even kernels we need to do asymmetric padding :(
-    out_padding = 2 * len(kernel_size) * [0]
-
-    for i in range(len(kernel_size)):
-        computed_tmp = computed[-(i + 1)]
-
-        pad_front = computed_tmp // 2
-        pad_rear = computed_tmp - pad_front
-
-        out_padding[2 * i + 0] = pad_front
-        out_padding[2 * i + 1] = pad_rear
-
-    return out_padding
-
-
-def _filter2d(input, kernel):
-    # prepare kernel
-    b, c, h, w = input.shape
-    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
-
-    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
-
-    height, width = tmp_kernel.shape[-2:]
-
-    padding_shape: List[int] = _compute_padding([height, width])
-    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
-
-    # kernel and input tensor reshape to align element-wise or batch-wise params
-    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
-    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
-
-    # convolve the tensor with the kernel.
-    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
-
-    out = output.view(b, c, h, w)
-    return out
-
-
-def _gaussian(window_size: int, sigma):
-    if isinstance(sigma, float):
-        sigma = torch.tensor([[sigma]])
-
-    batch_size = sigma.shape[0]
-
-    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
-
-    if window_size % 2 == 0:
-        x = x + 0.5
-
-    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
-
-    return gauss / gauss.sum(-1, keepdim=True)
-
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -285,7 +177,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
     """
 
     _optional_components = []
-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
 
     _callback_tensor_inputs = [
         "latents",
@@ -297,7 +189,6 @@ def __init__(
         self,
         tokenizer: T5Tokenizer,
         text_encoder: T5EncoderModel,
-        image_encoder: AutoencoderKLCogVideoX,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -307,7 +198,6 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer,
             text_encoder=text_encoder,
-            image_encoder=image_encoder,
             vae=vae,
             transformer=transformer,
             scheduler=scheduler,
@@ -321,45 +211,6 @@ def __init__(
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
-    def _encode_image(
-        self,
-        image: PipelineImageInput,
-        device: Union[str, torch.device],
-        num_videos_per_prompt: int,
-        do_classifier_free_guidance: bool,
-    ) -> torch.Tensor:
-        dtype = next(self.image_encoder.parameters()).dtype
-
-        if not isinstance(image, torch.Tensor):
-            image = self.video_processor.pil_to_numpy(image)
-            image = self.video_processor.numpy_to_pt(image)
-
-        # We normalize the image before resizing to match with the original implementation.
-        # Then we unnormalize it after resizing.
-        image = image * 2.0 - 1.0
-        image = _resize_with_antialiasing(image, (224, 224))
-        image = (image + 1.0) / 2.0
-
-        # encode image using VAE
-        image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image).image_embeds
-        image_embeddings = image_embeddings.unsqueeze(1)
-
-        # duplicate image embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = image_embeddings.shape
-        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
-        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            negative_image_embeddings = torch.zeros_like(image_embeddings)
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
-
-        return image_embeddings
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
         self,
@@ -486,23 +337,65 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds
 
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self,
+        image: Optional[torch.Tensor] = None,
+        batch_size: int = 1,
+        num_channels_latents: int = 16,
+        num_frames: int = 13,
+        height: int = 60,
+        width: int = 90,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.Tensor] = None,
     ):
+        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
-            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_frames,
             num_channels_latents,
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            assert image.ndim == 4
+            image = image.unsqueeze(2)  # [B, C, F, H, W]
+
+            if isinstance(generator, list):
+                if len(generator) != batch_size:
+                    raise ValueError(
+                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    )
+
+                init_latents = [
+                    retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+                ]
+            else:
+                init_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+            init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+            padding_shape = (
+                batch_size,
+                num_frames - 1,
+                num_channels_latents,
+                height // self.vae_scale_factor_spatial,
+                width // self.vae_scale_factor_spatial,
+            )
+            latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+            init_latents = torch.cat([init_latents, latent_padding], dim=1)
+
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch.cat([noise, init_latents], dim=2)
         else:
             latents = latents.to(device)
 
@@ -811,17 +704,7 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
-        # 3. Encode input prompt and image prompt
-        image_embeddings = self._encode_image(
-            image=image,
-            device=device,
-            num_videos_per_prompt=num_videos_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-        )
-        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
-        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
-        image = image + noise_aug_strength * noise
-
+        # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -837,12 +720,14 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents
+        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
+
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
+            image,
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
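Taken together, the diff swaps the CLIP-style image-embedding path for VAE latent conditioning: the first frame is encoded with the video VAE, zero-padded along the frame axis, and paired with the noise latents. A minimal standalone sketch of that construction, assuming illustrative shapes and a toy stand-in for the encoder (`toy_encode`, `scaling_factor`, and all sizes here are hypothetical, not the library's API):

```python
import torch

# Illustrative sizes; the real values come from the VAE and transformer configs.
batch_size, latent_channels = 1, 16
latent_frames, latent_height, latent_width = 13, 60, 90
scaling_factor = 0.7  # stand-in for vae.config.scaling_factor


def toy_encode(image: torch.Tensor) -> torch.Tensor:
    # Stand-in for retrieve_latents(self.vae.encode(...)):
    # maps [B, C, F=1, H, W] pixels to [B, latent_channels, 1, h, w] latents.
    return torch.randn(image.shape[0], latent_channels, 1, latent_height, latent_width)


# One conditioning frame per video, already preprocessed to [B, C, F=1, H, W].
image = torch.rand(batch_size, 3, 1, 480, 720)

# Encode the frame, scale, and move to the [B, F, C, H, W] layout used for noise.
init_latents = toy_encode(image).permute(0, 2, 1, 3, 4) * scaling_factor

# Zero-pad the remaining latent frames: only frame 0 carries image information.
padding = torch.zeros(batch_size, latent_frames - 1, latent_channels, latent_height, latent_width)
init_latents = torch.cat([init_latents, padding], dim=1)  # [B, F, C, H, W]

# Stack with noise on the channel axis (dim=2 in this layout), mirroring
# `latents = torch.cat([noise, init_latents], dim=2)` in prepare_latents.
noise = torch.randn_like(init_latents)
latents = torch.cat([noise, init_latents], dim=2)  # [B, F, 2 * C, H, W]
print(latents.shape)  # torch.Size([1, 13, 32, 60, 90])
```

Because noise and image latents are stacked on the channel axis, the transformer consuming these latents is expected to take twice the VAE's latent channel count as its input width.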