
Commit b033aad

refactor
1 parent 87535d6 commit b033aad

4 files changed, +110 -149 lines

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 36 additions & 2 deletions
@@ -140,6 +140,7 @@ def convert_transformer(
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
+    init_kwargs: Dict[str, Any]
 ):
     PREFIX_KEY = "model.diffusion_model."

@@ -150,6 +151,7 @@ def convert_transformer(
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
         use_learned_positional_embeddings=i2v,
+        **init_kwargs,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
@@ -163,6 +165,7 @@ def convert_transformer(
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
+
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer

@@ -187,6 +190,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae


+def get_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 768 // vae_scale_factor_spatial,
+            "sample_width": 1360 // vae_scale_factor_spatial,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -214,7 +245,8 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the model to be converted is the Image-to-Video version of CogVideoX.")
+    parser.add_argument("--version", choices=["1.0", "1.5"], default="1.0", help="Which version of CogVideoX to use for initializing default modeling parameters.")
     return parser.parse_args()


@@ -230,18 +262,20 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
             args.num_attention_heads,
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

-    text_encoder_id = "/share/home/zyx/Models/CogVideoX1.1-5B-SAT/t5-v1_1-xxl"
+    text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
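As a rough orientation, the sketch below (not part of the commit) works out what these version defaults mean in video-token counts. It assumes the temporal_compression_ratio of 4 and the patch sizes that appear elsewhere in this diff, and the helper name describe is made up for illustration.

# Illustrative sketch only (not from the commit): token counts implied by get_init_kwargs.
# sample_height/sample_width are the latent sizes returned above (pixels // vae_scale_factor_spatial).
def describe(version, sample_frames, sample_height, sample_width, patch_size, patch_size_t):
    latent_frames = (sample_frames - 1) // 4 + 1  # assumes temporal_compression_ratio = 4
    if patch_size_t is not None:
        # CogVideoX 1.5 pads frames up to a multiple of patch_size_t (see cogvideox_transformer_3d.py below)
        latent_frames = (latent_frames + (patch_size_t - latent_frames % patch_size_t)) // patch_size_t
    video_tokens = latent_frames * (sample_height // patch_size) * (sample_width // patch_size)
    print(version, video_tokens)

describe("1.0", 49, 60, 90, 2, None)  # 13 * 30 * 45 = 17550 video tokens
describe("1.5", 81, 96, 170, 2, 2)    # 11 * 48 * 85 = 44880 video tokens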

src/diffusers/models/embeddings.py

Lines changed: 27 additions & 125 deletions
@@ -334,127 +334,12 @@ def forward(self, x, freqs_cis):
             freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
         )

-class CogVideoX1_1PatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size: int = 2,
-        in_channels: int = 16,
-        embed_dim: int = 1920,
-        text_embed_dim: int = 4096,
-        sample_width: int = 90,
-        sample_height: int = 60,
-        sample_frames: int = 81,
-        temporal_compression_ratio: int = 4,
-        max_text_seq_length: int = 226,
-        spatial_interpolation_scale: float = 1.875,
-        temporal_interpolation_scale: float = 1.0,
-        use_positional_embeddings: bool = True,
-        use_learned_positional_embeddings: bool = True,
-    ) -> None:
-        super().__init__()
-
-        # Adjust patch_size to handle three dimensions
-        self.patch_size = (patch_size, patch_size, patch_size)  # (depth, height, width)
-        self.embed_dim = embed_dim
-        self.sample_height = sample_height
-        self.sample_width = sample_width
-        self.sample_frames = sample_frames
-        self.temporal_compression_ratio = temporal_compression_ratio
-        self.max_text_seq_length = max_text_seq_length
-        self.spatial_interpolation_scale = spatial_interpolation_scale
-        self.temporal_interpolation_scale = temporal_interpolation_scale
-        self.use_positional_embeddings = use_positional_embeddings
-        self.use_learned_positional_embeddings = use_learned_positional_embeddings
-
-        # Use Linear layer for projection
-        self.proj = nn.Linear(in_channels * (patch_size ** 3), embed_dim)
-        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
-        if use_positional_embeddings or use_learned_positional_embeddings:
-            persistent = use_learned_positional_embeddings
-            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
-            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
-
-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
-        post_patch_height = sample_height // self.patch_size[1]
-        post_patch_width = sample_width // self.patch_size[2]
-        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
-        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
-        pos_embedding = get_3d_sincos_pos_embed(
-            self.embed_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            self.spatial_interpolation_scale,
-            self.temporal_interpolation_scale,
-        )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
-        joint_pos_embedding = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
-        joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
-
-        return joint_pos_embedding
-
-    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
-        """
-        Args:
-            text_embeds (torch.Tensor): Input text embeddings of shape (batch_size, seq_length, embedding_dim).
-            image_embeds (torch.Tensor): Input image embeddings of shape (batch_size, num_frames, channels, height, width).
-        """
-        text_embeds = self.text_proj(text_embeds)
-        first_frame = image_embeds[:, 0:1, :, :, :]
-        duplicated_first_frame = first_frame.repeat(1, 2, 1, 1, 1)  # (batch, 2, channels, height, width)
-        # Copy the first frames, for t_patch
-        image_embeds = torch.cat([duplicated_first_frame, image_embeds[:, 1:, :, :, :]], dim=1)
-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.permute(0, 2, 1, 3, 4).contiguous()
-        image_embeds = image_embeds.view(batch, channels, -1).permute(0, 2, 1)
-
-        rope_patch_t = num_frames // self.patch_size[0]
-        rope_patch_h = height // self.patch_size[1]
-        rope_patch_w = width // self.patch_size[2]
-
-        image_embeds = image_embeds.view(
-            batch,
-            rope_patch_t, self.patch_size[0],
-            rope_patch_h, self.patch_size[1],
-            rope_patch_w, self.patch_size[2],
-            channels
-        )
-        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
-        image_embeds = image_embeds.view(batch, rope_patch_t * rope_patch_h * rope_patch_w, -1)
-        image_embeds = self.proj(image_embeds)
-        # Concatenate text and image embeddings
-        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
-
-        # Add positional embeddings if applicable
-        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
-            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
-                raise ValueError(
-                    "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
-                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
-                )
-
-            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
-
-            if (
-                self.sample_height != height
-                or self.sample_width != width
-                or self.sample_frames != pre_time_compression_frames
-            ):
-                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
-            else:
-                pos_embedding = self.pos_embedding
-
-            embeds = embeds + pos_embedding
-
-        return embeds
-

 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
@@ -472,6 +357,7 @@ def __init__(
         super().__init__()

         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -483,9 +369,15 @@ def __init__(
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings

-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
-        )
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)

         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -524,12 +416,22 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         """
         text_embeds = self.text_proj(text_embeds)

-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.reshape(-1, channels, height, width)
-        image_embeds = self.proj(image_embeds)
-        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
-        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
-        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)

         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
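To make the new CogVideoX 1.5 branch of CogVideoXPatchEmbed.forward concrete, here is a minimal standalone sketch (not part of the commit). The batch size, frame count, spatial size, and embed_dim are illustrative assumptions; only the permute/reshape/flatten sequence and the Linear projection mirror the diff above.

import torch

# Minimal sketch of the patch_size_t branch added to CogVideoXPatchEmbed.forward.
batch_size, num_frames, channels, height, width = 2, 22, 16, 96, 170  # illustrative latent shape
p, p_t, embed_dim = 2, 2, 64  # embed_dim is arbitrary here

image_embeds = torch.randn(batch_size, num_frames, channels, height, width)
proj = torch.nn.Linear(channels * p * p * p_t, embed_dim)  # mirrors the new nn.Linear projection

x = image_embeds.permute(0, 1, 3, 4, 2)  # [B, F, H, W, C]
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)  # [B, (F/p_t)*(H/p)*(W/p), C*p_t*p*p]
print(proj(x).shape)  # torch.Size([2, 44880, 64]) -> 11 * 48 * 85 video tokens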

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 33 additions & 14 deletions
@@ -24,7 +24,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, CogVideoX1_1PatchEmbed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -227,6 +227,7 @@ def __init__(
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: int = 2,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -237,6 +238,7 @@ def __init__(
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -249,15 +251,13 @@ def __init__(
         )

         # 1. Patch embedding
-        #TODO: different git push --set-upstream origin cogvideox1.1-5b
-
-        # self.patch_embed = CogVideoXPatchEmbed(
-        self.patch_embed = CogVideoX1_1PatchEmbed(
+        self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            # bias=True, # Only using in CogVideoX-5B
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -301,7 +301,15 @@ def __init__(
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * patch_size * out_channels)  # For CogVideoX1.1-5B
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)

         self.gradient_checkpointing = False

@@ -446,6 +454,16 @@ def forward(
         emb = self.time_embedding(t_emb, timestep_cond)

         # 2. Patch embedding
+        p = self.config.patch_size
+        p_t = self.config.patch_size_t
+
+        # We know that the hidden states height and width will always be divisible by patch_size.
+        # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
+        if p_t is not None:
+            remaining_frames = p_t - num_frames % p_t
+            first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
+            hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
+
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)

@@ -494,17 +512,18 @@ def custom_forward(*inputs):
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
-        p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+            output = output[:, remaining_frames:]

         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)

         if not return_dict:
             return (output,)
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
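For completeness, here is a minimal standalone sketch (not part of the commit) of the new temporal padding in forward() together with the matching 1.5 unpatchify and crop. The tensor shapes and the out_channels value are illustrative assumptions, and the projected states are random stand-ins for the real transformer output.

import torch

# Sketch of the patch_size_t frame padding and the matching unpatchify/crop added above.
batch_size, num_frames, out_channels, height, width = 1, 21, 16, 96, 170  # illustrative
p, p_t = 2, 2

# Pad the frame axis to a multiple of p_t by repeating the first frame.
remaining_frames = p_t - num_frames % p_t  # 1 here
hidden_states = torch.randn(batch_size, num_frames, out_channels, height, width)
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
print(hidden_states.shape[1])  # 22, now divisible by p_t

# Stand-in for the transformer output after proj_out: one vector of size p*p*p_t*out_channels
# per (temporal, height, width) patch.
seq_len = (num_frames + p_t - 1) // p_t * (height // p) * (width // p)
proj_out_states = torch.randn(batch_size, seq_len, p * p * p_t * out_channels)

# Unpatchify as in the new else-branch, then drop the padded frames again.
output = proj_out_states.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
print(output.shape)  # torch.Size([1, 21, 16, 96, 170])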
