
Commit 21a6f79

make use of learned positional embeddings
1 parent 4f89426 commit 21a6f79

File tree

3 files changed (+17 lines, -5 lines)


scripts/convert_cogvideox_to_diffusers.py

Lines changed: 2 additions & 2 deletions
@@ -84,6 +84,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
 }

 TRANSFORMER_SPECIAL_KEYS_REMAP = {
@@ -95,8 +96,6 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "freqs_sin": remove_keys_inplace,
     "freqs_cos": remove_keys_inplace,
     "position_embedding": remove_keys_inplace,
-    # TODO zRzRzRzRzRzRzR: really need to remove?
-    "pos_embedding": remove_keys_inplace,
 }

 VAE_KEYS_RENAME_DICT = {
@@ -150,6 +149,7 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        use_learned_positional_embeddings=i2v,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
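
For context, the two tables above are consumed by the script's conversion loop: matching substrings in checkpoint keys are renamed, and special keys (such as the precomputed sincos tables, which are simply dropped) are routed through handler functions. The following is a simplified sketch of that pattern, not the script's exact code; the helper and function names here are trimmed down for illustration.

from typing import Any, Dict

def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
    # Keys with no counterpart in the diffusers layout are simply dropped.
    state_dict.pop(key)

TRANSFORMER_KEYS_RENAME_DICT = {
    # With this commit, the learned positional embedding of CogVideoX-5b-I2V
    # is renamed and kept instead of being discarded.
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "position_embedding": remove_keys_inplace,  # sincos table, recreated at load time
}

def convert_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Substring-based renaming, mirroring the rename dict above.
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(old, new)
        state_dict[new_key] = state_dict.pop(key)
    # Special keys are dispatched to their handlers.
    for special_key, handler in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
        for key in list(state_dict.keys()):
            if special_key in key:
                handler(key, state_dict)
    return state_dict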

src/diffusers/models/embeddings.py

Lines changed: 13 additions & 3 deletions
@@ -350,6 +350,7 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()

@@ -363,15 +364,17 @@ def __init__(
         self.spatial_interpolation_scale = spatial_interpolation_scale
         self.temporal_interpolation_scale = temporal_interpolation_scale
         self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings

         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)

-        if use_positional_embeddings:
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
-            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

     def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
@@ -415,8 +418,15 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

-        if self.use_positional_embeddings:
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
             pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
             if (
                 self.sample_height != height
                 or self.sample_width != width
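
The persistent flag on register_buffer is what lets the learned table survive serialization: persistent buffers are included in a module's state_dict, so the trained pos_embedding of CogVideoX-5b-I2V is saved and loaded with the checkpoint, while non-persistent buffers (the computed sincos table used by the other checkpoints) are rebuilt at init and never stored. A minimal, self-contained illustration of that difference; the Toy module here is purely for demonstration.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, learned: bool) -> None:
        super().__init__()
        pos = torch.zeros(1, 16, 8)  # stand-in for the positional table
        # persistent=True  -> appears in state_dict, saved/loaded with the checkpoint
        # persistent=False -> recomputed at init, absent from state_dict
        self.register_buffer("pos_embedding", pos, persistent=learned)

print("pos_embedding" in Toy(learned=True).state_dict())   # True
print("pos_embedding" in Toy(learned=False).state_dict())  # False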

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 2 additions & 0 deletions
@@ -235,6 +235,7 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -254,6 +255,7 @@ def __init__(
             spatial_interpolation_scale=spatial_interpolation_scale,
             temporal_interpolation_scale=temporal_interpolation_scale,
             use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
         )
         self.embedding_dropout = nn.Dropout(dropout)
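
With the flag threaded through the transformer config, a converted CogVideoX-5b-I2V transformer loads its trained positional table automatically, while existing text-to-video checkpoints keep the default use_learned_positional_embeddings=False. A hedged usage sketch follows; the repo id comes from the error message above, and the subfolder layout and dtype are assumptions.

import torch
from diffusers import CogVideoXTransformer3DModel

# Load the converted transformer; the learned pos_embedding is restored
# from the checkpoint because it was registered as a persistent buffer.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(transformer.config.use_learned_positional_embeddings)  # expected: True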
