@@ -350,6 +350,7 @@ def __init__(
350350 spatial_interpolation_scale : float = 1.875 ,
351351 temporal_interpolation_scale : float = 1.0 ,
352352 use_positional_embeddings : bool = True ,
353+ use_learned_positional_embeddings : bool = True ,
353354 ) -> None :
354355 super ().__init__ ()
355356
@@ -363,15 +364,17 @@ def __init__(
363364 self .spatial_interpolation_scale = spatial_interpolation_scale
364365 self .temporal_interpolation_scale = temporal_interpolation_scale
365366 self .use_positional_embeddings = use_positional_embeddings
367+ self .use_learned_positional_embeddings = use_learned_positional_embeddings
366368
367369 self .proj = nn .Conv2d (
368370 in_channels , embed_dim , kernel_size = (patch_size , patch_size ), stride = patch_size , bias = bias
369371 )
370372 self .text_proj = nn .Linear (text_embed_dim , embed_dim )
371373
372- if use_positional_embeddings :
374+ if use_positional_embeddings or use_learned_positional_embeddings :
375+ persistent = use_learned_positional_embeddings
373376 pos_embedding = self ._get_positional_embeddings (sample_height , sample_width , sample_frames )
374- self .register_buffer ("pos_embedding" , pos_embedding , persistent = False )
377+ self .register_buffer ("pos_embedding" , pos_embedding , persistent = persistent )
375378
376379 def _get_positional_embeddings (self , sample_height : int , sample_width : int , sample_frames : int ) -> torch .Tensor :
377380 post_patch_height = sample_height // self .patch_size
@@ -415,8 +418,15 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
415418 [text_embeds , image_embeds ], dim = 1
416419 ).contiguous () # [batch, seq_length + num_frames x height x width, channels]
417420
418- if self .use_positional_embeddings :
421+ if self .use_positional_embeddings or self .use_learned_positional_embeddings :
422+ if self .use_learned_positional_embeddings and (self .sample_width != width or self .sample_height != height ):
423+ raise ValueError (
424+ "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
425+ "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
426+ )
427+
419428 pre_time_compression_frames = (num_frames - 1 ) * self .temporal_compression_ratio + 1
429+
420430 if (
421431 self .sample_height != height
422432 or self .sample_width != width
0 commit comments