@@ -334,127 +334,12 @@ def forward(self, x, freqs_cis):
             freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
         )
 
-class CogVideoX1_1PatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size: int = 2,
-        in_channels: int = 16,
-        embed_dim: int = 1920,
-        text_embed_dim: int = 4096,
-        sample_width: int = 90,
-        sample_height: int = 60,
-        sample_frames: int = 81,
-        temporal_compression_ratio: int = 4,
-        max_text_seq_length: int = 226,
-        spatial_interpolation_scale: float = 1.875,
-        temporal_interpolation_scale: float = 1.0,
-        use_positional_embeddings: bool = True,
-        use_learned_positional_embeddings: bool = True,
-    ) -> None:
-        super().__init__()
-
-        # Expand patch_size to cover all three dimensions
-        self.patch_size = (patch_size, patch_size, patch_size)  # (depth, height, width)
-        self.embed_dim = embed_dim
-        self.sample_height = sample_height
-        self.sample_width = sample_width
-        self.sample_frames = sample_frames
-        self.temporal_compression_ratio = temporal_compression_ratio
-        self.max_text_seq_length = max_text_seq_length
-        self.spatial_interpolation_scale = spatial_interpolation_scale
-        self.temporal_interpolation_scale = temporal_interpolation_scale
-        self.use_positional_embeddings = use_positional_embeddings
-        self.use_learned_positional_embeddings = use_learned_positional_embeddings
-
-        # Use a Linear layer for the patch projection
-        self.proj = nn.Linear(in_channels * (patch_size**3), embed_dim)
-        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
-        if use_positional_embeddings or use_learned_positional_embeddings:
-            persistent = use_learned_positional_embeddings
-            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
-            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
-
-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
-        post_patch_height = sample_height // self.patch_size[1]
-        post_patch_width = sample_width // self.patch_size[2]
-        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
-        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
-        pos_embedding = get_3d_sincos_pos_embed(
-            self.embed_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            self.spatial_interpolation_scale,
-            self.temporal_interpolation_scale,
-        )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
-        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
-        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
-
-        return joint_pos_embedding
-
-    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
-        """
-        Args:
-            text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
-            image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
-        """
-        text_embeds = self.text_proj(text_embeds)
-        first_frame = image_embeds[:, 0:1, :, :, :]
-        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1)  # (batch, 2, channels, height, width)
-        # Duplicate the first frame so num_frames is divisible by the temporal patch size
-        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
-        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
-
-        rope_patch_t = num_frames // self.patch_size[0]
-        rope_patch_h = height // self.patch_size[1]
-        rope_patch_w = width // self.patch_size[2]
-
-        image_embeds = image_embeds.view(
-            batch,
-            rope_patch_t, self.patch_size[0],
-            rope_patch_h, self.patch_size[1],
-            rope_patch_w, self.patch_size[2],
-            channels,
-        )
-        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
-        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
-        image_embeds = self.proj(image_embeds)
-        # Concatenate text and image embeddings
-        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
-
-        # Add positional embeddings if applicable
-        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
-            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
-                raise ValueError(
-                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
-                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
-                )
-
-            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
-
-            if (
-                self.sample_height != height
-                or self.sample_width != width
-                or self.sample_frames != pre_time_compression_frames
-            ):
-                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
-            else:
-                pos_embedding = self.pos_embedding
-
-            embeds = embeds + pos_embedding
-
-        return embeds
-
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
@@ -472,6 +357,7 @@ def __init__(
         super().__init__()
 
         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -483,9 +369,15 @@ def __init__(
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
-        )
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -524,12 +416,22 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
524416 """
525417 text_embeds = self .text_proj (text_embeds )
526418
527- batch , num_frames , channels , height , width = image_embeds .shape
528- image_embeds = image_embeds .reshape (- 1 , channels , height , width )
529- image_embeds = self .proj (image_embeds )
530- image_embeds = image_embeds .view (batch , num_frames , * image_embeds .shape [1 :])
531- image_embeds = image_embeds .flatten (3 ).transpose (2 , 3 ) # [batch, num_frames, height x width, channels]
532- image_embeds = image_embeds .flatten (1 , 2 ) # [batch, num_frames x height x width, channels]
419+ batch_size , num_frames , channels , height , width = image_embeds .shape
420+
421+ if self .patch_size_t is None :
422+ image_embeds = image_embeds .reshape (- 1 , channels , height , width )
423+ image_embeds = self .proj (image_embeds )
424+ image_embeds = image_embeds .view (batch_size , num_frames , * image_embeds .shape [1 :])
425+ image_embeds = image_embeds .flatten (3 ).transpose (2 , 3 ) # [batch, num_frames, height x width, channels]
426+ image_embeds = image_embeds .flatten (1 , 2 ) # [batch, num_frames x height x width, channels]
427+ else :
428+ p = self .patch_size
429+ p_t = self .patch_size_t
430+
431+ image_embeds = image_embeds .permute (0 , 1 , 3 , 4 , 2 )
432+ image_embeds = image_embeds .reshape (batch_size , num_frames // p_t , p_t , height // p , p , width // p , p , channels )
433+ image_embeds = image_embeds .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (4 , 7 ).flatten (1 , 3 )
434+ image_embeds = self .proj (image_embeds )
533435
534436 embeds = torch .cat (
535437 [text_embeds , image_embeds ], dim = 1
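
Aside: a quick way to sanity-check the new patch_size_t branch above is to replay its permute/reshape/flatten sequence on a dummy tensor and confirm the token count and feature width. The sketch below is illustrative only; the toy sizes (batch 1, 8 frames, 16 channels, 32x32 latents, p = 2, p_t = 2) are assumptions, not values from this PR.

import torch
import torch.nn as nn

# Assumed toy sizes for illustration (not from the diff)
batch_size, num_frames, channels, height, width = 1, 8, 16, 32, 32
p, p_t, embed_dim = 2, 2, 64

image_embeds = torch.randn(batch_size, num_frames, channels, height, width)

# Same rearrangement as the patch_size_t branch: channels-last, split each
# axis into (num_patches, patch), then gather one token per p_t x p x p patch
x = image_embeds.permute(0, 1, 3, 4, 2)
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)

assert x.shape == (batch_size, (num_frames // p_t) * (height // p) * (width // p), channels * p_t * p * p)

proj = nn.Linear(channels * p * p * p_t, embed_dim)  # mirrors the new self.proj
print(proj(x).shape)  # torch.Size([1, 1024, 64])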
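
Both projection paths end in the same (batch_size, seq_len, embed_dim) token layout, which is why the downstream torch.cat with the text tokens needs no change. A minimal sketch of the retained 1.0 Conv2d path, under the same assumed toy sizes:

import torch
import torch.nn as nn

batch_size, num_frames, channels, height, width = 1, 8, 16, 32, 32
p, embed_dim = 2, 64

image_embeds = torch.randn(batch_size, num_frames, channels, height, width)

# CogVideoX 1.0 branch: a 2D conv patchifies every frame independently
proj = nn.Conv2d(channels, embed_dim, kernel_size=(p, p), stride=p)
x = image_embeds.reshape(-1, channels, height, width)  # fold frames into the batch
x = proj(x)                                            # (B * T, embed_dim, H/p, W/p)
x = x.view(batch_size, num_frames, *x.shape[1:])
x = x.flatten(3).transpose(2, 3)                       # (B, T, H/p * W/p, embed_dim)
x = x.flatten(1, 2)                                    # (B, T * H/p * W/p, embed_dim)
print(x.shape)  # torch.Size([1, 2048, 64])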