Commit b02915b

CogVideoX1_1PatchEmbed test
1 parent 76b7d86 commit b02915b

File tree

4 files changed: +129 -14 lines changed


scripts/convert_cogvideox_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def get_args():
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
 
-    text_encoder_id = "google/t5-v1_1-xxl"
+    text_encoder_id = "/share/home/zyx/Models/CogVideoX1.1-5B-SAT/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

src/diffusers/models/embeddings.py

Lines changed: 117 additions & 0 deletions
@@ -17,6 +17,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+from einops import rearrange
 from torch import nn
 
 from ..utils import deprecate
@@ -333,6 +334,122 @@ def forward(self, x, freqs_cis):
             freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
         )
 
+class CogVideoX1_1PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 81,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Adjust patch_size to handle three dimensions
+        self.patch_size = (patch_size, patch_size, patch_size)  # (depth, height, width)
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+        # Use a Linear layer for the patch projection
+        self.proj = nn.Linear(in_channels * (patch_size**3), embed_dim)
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size[1]
+        post_patch_width = sample_width // self.patch_size[2]
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        """
+        Args:
+            text_embeds (`torch.Tensor`): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        # Duplicate the first frame so the temporal axis is divisible by the temporal patch size
+        first_frame = image_embeds[:, 0:1, :, :, :]
+        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1)  # (batch, 2, channels, height, width)
+        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
+        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
+
+        rope_patch_t = num_frames // self.patch_size[0]
+        rope_patch_h = height // self.patch_size[1]
+        rope_patch_w = width // self.patch_size[2]
+
+        image_embeds = image_embeds.view(
+            batch,
+            rope_patch_t, self.patch_size[0],
+            rope_patch_h, self.patch_size[1],
+            rope_patch_w, self.patch_size[2],
+            channels,
+        )
+        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
+        image_embeds = self.proj(image_embeds)
+
+        # Concatenate text and image embeddings
+        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
+
+        # Add positional embeddings if applicable
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
+        return embeds
+
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
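
For context, a minimal shape check of the new patch embed (a sketch: it assumes this branch of diffusers is installed, and it disables the sinusoidal/learned positional embeddings since the CogVideoX-5B checkpoints rely on rotary embeddings):

import torch
from diffusers.models.embeddings import CogVideoX1_1PatchEmbed

embed = CogVideoX1_1PatchEmbed(
    patch_size=2,
    in_channels=16,
    embed_dim=1920,
    text_embed_dim=4096,
    use_positional_embeddings=False,
    use_learned_positional_embeddings=False,
)

text_embeds = torch.randn(1, 226, 4096)        # (batch, max_text_seq_length, text_embed_dim)
image_embeds = torch.randn(1, 21, 16, 60, 90)  # (batch, latent_frames, channels, height, width)

out = embed(text_embeds, image_embeds)
# The first latent frame is duplicated (21 -> 22) so the temporal axis is divisible by the
# temporal patch size of 2: (22 / 2) * (60 / 2) * (90 / 2) = 14850 video tokens + 226 text tokens.
print(out.shape)  # torch.Size([1, 15076, 1920])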

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 6 additions & 5 deletions
@@ -24,7 +24,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, CogVideoX1_1PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -249,12 +249,13 @@ def __init__(
         )
 
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(
+        # self.patch_embed = CogVideoXPatchEmbed(
+        self.patch_embed = CogVideoX1_1PatchEmbed(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            # bias=True,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -298,7 +299,7 @@ def __init__(
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)  # For CogVideoX1.1-5B
 
         self.gradient_checkpointing = False

@@ -504,4 +505,4 @@ def custom_forward(*inputs):
 
         if not return_dict:
             return (output,)
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
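
The proj_out change above means each output token now carries a patch_size**3 block of latent values instead of a patch_size**2 one. A minimal sketch of the corresponding 3D un-patchify step (a hypothetical helper for illustration, mirroring the (channels, p_t, p_h, p_w) packing used by the patch embed; it is not part of this commit):

import torch

def unpatchify_3d(tokens: torch.Tensor, t: int, h: int, w: int, p: int, out_channels: int) -> torch.Tensor:
    # tokens: (batch, t * h * w, out_channels * p**3), one 3D patch per token
    batch = tokens.shape[0]
    x = tokens.view(batch, t, h, w, out_channels, p, p, p)
    x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)  # (batch, C, t, p_t, h, p_h, w, p_w)
    return x.reshape(batch, out_channels, t * p, h * p, w * p)

tokens = torch.randn(1, 11 * 30 * 45, 16 * 2**3)
print(unpatchify_3d(tokens, t=11, h=30, w=45, p=2, out_channels=16).shape)  # torch.Size([1, 16, 22, 60, 90])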

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 5 additions & 8 deletions
@@ -442,8 +442,11 @@ def _prepare_rotary_positional_embeddings(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        # TODO: Handle both the CogVideoX-5B and CogVideoX1.1-5B models here.
+        # CogVideoX 1.0 uses 720 x 480, CogVideoX1.1-5B T2V uses 768 x 1360, and CogVideoX1.1-5B I2V follows the input image.
+        base_size_width = 768 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 1360 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         grid_crops_coords = get_resize_crop_region_for_grid(
             (grid_height, grid_width), base_size_width, base_size_height
@@ -583,11 +586,6 @@ def __call__(
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

@@ -679,7 +677,6 @@ def __call__(
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latent_model_input.shape[0])
-
         # predict noise model_output
         noise_pred = self.transformer(
             hidden_states=latent_model_input,
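
For reference, the grid arithmetic behind the swapped base sizes (a standalone sketch assuming the usual VAE spatial scale factor of 8 and patch size of 2 for these checkpoints):

def base_grid(pixel_width: int, pixel_height: int, vae_scale_factor_spatial: int = 8, patch_size: int = 2):
    # Number of patch tokens along each axis after VAE downsampling and 2x2 spatial patching.
    return (
        pixel_width // (vae_scale_factor_spatial * patch_size),
        pixel_height // (vae_scale_factor_spatial * patch_size),
    )

print(base_grid(720, 480))   # CogVideoX 1.0 base grid: (45, 30)
print(base_grid(768, 1360))  # values hard-coded above for CogVideoX1.1-5B: (48, 85)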

0 commit comments
