 import numpy as np
 import torch
 import torch.nn.functional as F
+from einops import rearrange
 from torch import nn
 
 from ..utils import deprecate
@@ -333,6 +334,122 @@ def forward(self, x, freqs_cis): |
         freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
     )
 
+class CogVideoX1_1PatchEmbed(nn.Module):
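+    r"""
+    Patch embedding for CogVideoX 1.1 style video latents: the latents are split into 3D
+    (temporal x height x width) patches, flattened, linearly projected, and concatenated
+    after the projected text embeddings into a single joint token sequence.
+
+    Example (a sketch with dummy inputs, assuming the default configuration):
+
+        >>> embed = CogVideoX1_1PatchEmbed()
+        >>> text = torch.randn(1, 226, 4096)  # (batch, max_text_seq_length, text_embed_dim)
+        >>> video = torch.randn(1, 21, 16, 60, 90)  # (batch, latent_frames, channels, height, width)
+        >>> tokens = embed(text, video)
+        >>> tokens.shape  # 226 text tokens + 11 * 30 * 45 video tokens
+        torch.Size([1, 15076, 1920])
+    """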
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 81,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Expand patch_size to three dimensions: (temporal, height, width)
+        self.patch_size = (patch_size, patch_size, patch_size)
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+        # Flatten each 3D patch to in_channels * patch_size**3 values and project it to the
+        # transformer width with a Linear layer; text embeddings get their own projection
+        self.proj = nn.Linear(in_channels * (patch_size**3), embed_dim)
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
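+        # With the defaults: 16 * 2**3 = 128 -> 1920 for video patches, 4096 -> 1920 for text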
+
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size[1]
+        post_patch_width = sample_width // self.patch_size[2]
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        # forward() prepends a duplicate of the first latent frame before temporal patching,
+        # so the number of temporal tokens is (latent_frames + 1) // temporal_patch_size
+        post_patch_frames = (post_time_compression_frames + 1) // self.patch_size[0]
+        num_patches = post_patch_height * post_patch_width * post_patch_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_patch_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
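+    # Worked example with the default configuration (a sketch of the sizes involved):
+    # 60 // 2 = 30 and 90 // 2 = 45 spatial patches per frame, (81 - 1) // 4 + 1 = 21 latent
+    # frames, (21 + 1) // 2 = 11 temporal tokens; num_patches = 30 * 45 * 11 = 14850, and the
+    # joint table then holds 226 + 14850 = 15076 rows.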
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`): Input text embeddings of shape `(batch_size, seq_length, embedding_dim)`.
+            image_embeds (`torch.Tensor`): Input video latents of shape `(batch_size, num_frames, channels, height, width)`.
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        # Read the shape before the duplication below so that num_frames keeps its original
+        # meaning when the pixel-space frame count is recovered further down
+        batch, num_frames, channels, height, width = image_embeds.shape
+
+        # Duplicate the first frame so the frame count is divisible by the temporal patch size
+        first_frame = image_embeds[:, 0:1]  # (batch, 1, channels, height, width)
+        image_embeds = torch.cat([first_frame, image_embeds], dim=1)  # (batch, num_frames + 1, ...)
+
+        # Flatten to (batch, (num_frames + 1) * height * width, channels); the trailing
+        # .contiguous() is required because .view() below cannot follow a non-contiguous permute
+        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
+        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1).contiguous()
+
+        rope_patch_t = (num_frames + 1) // self.patch_size[0]
+        rope_patch_h = height // self.patch_size[1]
+        rope_patch_w = width // self.patch_size[2]
+
+        # Regroup the flat sequence into 3D patches and flatten each patch:
+        # (batch, seq, channels) -> (batch, t * h * w, channels * pt * ph * pw)
+        image_embeds = image_embeds.view(
+            batch,
+            rope_patch_t, self.patch_size[0],
+            rope_patch_h, self.patch_size[1],
+            rope_patch_w, self.patch_size[2],
+            channels,
+        )
+        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
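+        # Everything from the first permute above down to this flatten is equivalent to a
+        # single einops call (a sketch; `rearrange` is imported at the top of this file):
+        #     image_embeds = rearrange(
+        #         image_embeds, "b (t pt) c (h ph) (w pw) -> b (t h w) (c pt ph pw)",
+        #         pt=self.patch_size[0], ph=self.patch_size[1], pw=self.patch_size[2],
+        #     )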
+        image_embeds = self.proj(image_embeds)
+
+        # Concatenate text and image embeddings into a single joint sequence
+        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
+
+        # Add positional embeddings if applicable
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a resolution different from the defaults. "
+                    "This should only be the case with 'THUDM/CogVideoX-5b-I2V'. If you think this is incorrect, "
+                    "please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            # Recover the pixel-space frame count from the pre-duplication latent frame count
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
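+            # e.g. with the defaults: 21 latent frames -> (21 - 1) * 4 + 1 = 81 == sample_frames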
+
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
+        return embeds
+
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(