Skip to content

Commit b87b07e

Browse files
Add ofs embedding (for the conversion script)
1 parent d833f72 commit b87b07e

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
8080
"post_attn1_layernorm": "norm2.norm",
8181
"time_embed.0": "time_embedding.linear_1",
8282
"time_embed.2": "time_embedding.linear_2",
83+
"ofs_embed.0": "ofs_embedding.linear_1",
84+
"ofs_embed.2": "ofs_embedding.linear_2",
8385
"mixins.patch_embed": "patch_embed",
8486
"mixins.final_layer.norm_final": "norm_out.norm",
8587
"mixins.final_layer.linear": "proj_out",
@@ -150,7 +152,8 @@ def convert_transformer(
150152
num_layers=num_layers,
151153
num_attention_heads=num_attention_heads,
152154
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
153-
use_learned_positional_embeddings=i2v,
155+
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
156+
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
154157
**init_kwargs,
155158
).to(dtype=dtype)
156159

@@ -210,7 +213,7 @@ def get_init_kwargs(version: str):
210213
"patch_bias": False,
211214
"sample_height": 768 // vae_scale_factor_spatial,
212215
"sample_width": 1360 // vae_scale_factor_spatial,
213-
"sample_frames": 81,
216+
"sample_frames": 81, # TODO: Need Test with 161 for 10 seconds
214217
}
215218
else:
216219
raise ValueError("Unsupported version of CogVideoX.")

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def __init__(
219219
flip_sin_to_cos: bool = True,
220220
freq_shift: int = 0,
221221
time_embed_dim: int = 512,
222+
ofs_embed_dim: Optional[int] = 512,
222223
text_embed_dim: int = 4096,
223224
num_layers: int = 30,
224225
dropout: float = 0.0,
@@ -270,10 +271,15 @@ def __init__(
270271
)
271272
self.embedding_dropout = nn.Dropout(dropout)
272273

273-
# 2. Time embeddings
274+
# 2. Time embeddings and ofs embedding (only CogVideoX1.5-5B I2V has the ofs embedding)
275+
274276
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
275277
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
276278

279+
if ofs_embed_dim:
280+
self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs
281+
self.ofs_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
282+
277283
# 3. Define spatio-temporal transformers blocks
278284
self.transformer_blocks = nn.ModuleList(
279285
[

0 commit comments

Comments
 (0)