
Commit 67cb373 (parent: b033aad)

make style

4 files changed: +25 -12 lines

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 16 additions & 6 deletions

@@ -140,7 +140,7 @@ def convert_transformer(
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
-    init_kwargs: Dict[str, Any]
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -165,7 +165,7 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
-
+
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
 
@@ -201,7 +201,7 @@ def get_init_kwargs(version: str):
            "sample_width": 720 // vae_scale_factor_spatial,
            "sample_frames": 49,
        }
-
+
    elif version == "1.5":
        vae_scale_factor_spatial = 8
        init_kwargs = {
@@ -214,7 +214,7 @@ def get_init_kwargs(version: str):
        }
    else:
        raise ValueError("Unsupported version of CogVideoX.")
-
+
    return init_kwargs
 
 
@@ -245,8 +245,18 @@ def get_args():
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the model to be converted is the Image-to-Video version of CogVideoX.")
-    parser.add_argument("--version", choices=["1.0", "1.5"], default="1.0", help="Which version of CogVideoX to use for initializing default modeling parameters.")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
    return parser.parse_args()
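The multi-line parser.add_argument calls are behaviorally identical to the one-liners they replace. A minimal standalone sketch (covering only the two arguments touched by this hunk, not the full converter CLI) that checks the parsing behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--i2v",
    action="store_true",
    default=False,
    help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
)
parser.add_argument(
    "--version",
    choices=["1.0", "1.5"],
    default="1.0",
    help="Which version of CogVideoX to use for initializing default modeling parameters.",
)

# Same semantics as before the style pass: --i2v is a boolean flag,
# --version is restricted to "1.0" or "1.5".
args = parser.parse_args(["--i2v", "--version", "1.5"])
assert args.i2v is True and args.version == "1.5"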

src/diffusers/models/embeddings.py

Lines changed: 4 additions & 3 deletions

@@ -17,7 +17,6 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 from torch import nn
 
 from ..utils import deprecate
@@ -377,7 +376,7 @@ def __init__(
        else:
            # CogVideoX 1.5 checkpoints
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
-
+
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
        if use_positional_embeddings or use_learned_positional_embeddings:
@@ -429,7 +428,9 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        p_t = self.patch_size_t
 
        image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
-        image_embeds = image_embeds.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
+        image_embeds = image_embeds.reshape(
+            batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+        )
        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
        image_embeds = self.proj(image_embeds)
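The wrapped reshape is purely cosmetic; the patchify logic is unchanged. A standalone shape check of that path under assumed toy dimensions (the batch, frame, channel, and patch sizes here are illustrative, not CogVideoX defaults):

import torch

batch_size, num_frames, channels, height, width = 1, 4, 16, 8, 8
p, p_t = 2, 2  # spatial and temporal patch sizes

image_embeds = torch.randn(batch_size, num_frames, channels, height, width)
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)  # move channels last
image_embeds = image_embeds.reshape(
    batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
# Group each p_t x p x p patch into one token of size channels * p_t * p * p
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
expected_tokens = (num_frames // p_t) * (height // p) * (width // p)
assert image_embeds.shape == (batch_size, expected_tokens, channels * p_t * p * p)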

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 4 additions & 2 deletions

@@ -308,7 +308,7 @@ def __init__(
        else:
            # For CogVideoX 1.5
            output_dim = patch_size * patch_size * patch_size_t * out_channels
-
+
        self.proj_out = nn.Linear(inner_dim, output_dim)
 
        self.gradient_checkpointing = False
@@ -516,7 +516,9 @@ def custom_forward(*inputs):
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
-            output = hidden_states.reshape(batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
            output = output[:, remaining_frames:]
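The unpatchify reshape wrapped above is likewise unchanged in behavior: the token sequence is folded back into frames using ceil division on num_frames, and the leading padding frames are sliced off. A toy shape check (dimensions assumed; in the real forward, remaining_frames is computed earlier in the method):

import torch

batch_size, num_frames, out_channels = 1, 3, 4
height, width, p, p_t = 8, 8, 2, 2

padded_frames = (num_frames + p_t - 1) // p_t * p_t  # ceil to a multiple of p_t
remaining_frames = padded_frames - num_frames  # padding frames to drop
seq_len = (padded_frames // p_t) * (height // p) * (width // p)
hidden_states = torch.randn(batch_size, seq_len, p * p * p_t * out_channels)

output = hidden_states.reshape(
    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
assert output.shape == (batch_size, num_frames, out_channels, height, width)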

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 1 addition & 1 deletion

@@ -449,7 +449,7 @@ def _prepare_rotary_positional_embeddings(
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p
        base_num_frames = (num_frames + p_t - 1) // p_t
-
+
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
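Only trailing whitespace changes here, but the surrounding grid-size arithmetic is easy to sanity-check. A sketch with assumed values standing in for self.transformer.config (sample_width=90 and sample_height=60 would correspond to 720x480 pixels with a VAE spatial factor of 8):

# Assumed stand-ins for the transformer config and pipeline inputs
sample_width, sample_height = 90, 60  # latent sample dims (e.g. 720 // 8, 480 // 8)
num_frames, p, p_t = 13, 2, 2  # latent frames, spatial and temporal patch sizes

base_size_width = sample_width // p
base_size_height = sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t  # ceil division, as in the hunk

assert (base_size_width, base_size_height, base_num_frames) == (45, 30, 7)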
