@@ -170,14 +170,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Whether to flip the sin to cos in the time embedding.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, *optional*, defaults to `None`):
+            Output dimension of the "ofs" embedding, which encodes the VAE scaling factor used for Image-to-Video (I2V) in CogVideoX1.5-5B.
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         num_layers (`int`, defaults to `30`):
             The number of layers of Transformer blocks to use.
         dropout (`float`, defaults to `0.0`):
             The dropout probability to use.
         attention_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention projection layers.
+            Whether to use bias in the attention projection layers.
         sample_width (`int`, defaults to `90`):
             The width of the input latents.
         sample_height (`int`, defaults to `60`):
@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep_activation_fn (`str`, defaults to `"silu"`):
             Activation function to use when generating the timestep embeddings.
         norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether or not to use elementwise affine in normalization layers.
+            Whether to use elementwise affine in normalization layers.
         norm_eps (`float`, defaults to `1e-5`):
             The epsilon value to use in normalization layers.
         spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -219,7 +221,7 @@ def __init__(
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         time_embed_dim: int = 512,
-        ofs_embed_dim: Optional[int] = 512,
+        ofs_embed_dim: Optional[int] = None,
         text_embed_dim: int = 4096,
         num_layers: int = 30,
         dropout: float = 0.0,
@@ -276,9 +278,10 @@ def __init__(
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_embedding = None
+
         if ofs_embed_dim:
             self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn)  # same as time embeddings, for ofs
-            self.ofs_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
 
         # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
@@ -458,6 +461,9 @@ def forward(
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
+        if self.ofs_embedding is not None:
+            emb_ofs = self.ofs_embedding(emb, timestep_cond)
+            emb = emb + emb_ofs
 
         # 2. Patch embedding
         p = self.config.patch_size
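
For readers following along, here is a minimal standalone sketch of what the new `ofs_embedding` path computes, using the public `Timesteps`/`TimestepEmbedding` helpers from `diffusers.models.embeddings`. The dimensions are illustrative (taken from the defaults shown in this file), `timestep_cond` is omitted since it is `None` in the common case, and this is a sketch of the data flow, not the model's actual forward pass:

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Illustrative dimensions; inner_dim is num_attention_heads * attention_head_dim
# (30 * 64 with the defaults in this file).
inner_dim = 1920
time_embed_dim = 512
ofs_embed_dim = 512  # must equal time_embed_dim, since the ofs MLP consumes `emb`

time_proj = Timesteps(inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, act_fn="silu")
ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, act_fn="silu")

timestep = torch.tensor([999])  # one denoising step, batch size 1
t_emb = time_proj(timestep)     # (1, inner_dim) sinusoidal features
emb = time_embedding(t_emb)     # (1, time_embed_dim)

# New in this commit: when ofs is enabled (CogVideoX1.5-5B I2V), the time
# embedding is passed through a second MLP and added back onto itself.
emb = emb + ofs_embedding(emb)
print(emb.shape)  # torch.Size([1, 512])
```

Note the design choice this revision makes: by dropping the separate `ofs_proj` sinusoidal projection and feeding the already-computed time embedding into `ofs_embedding`, it implicitly requires `ofs_embed_dim == time_embed_dim` (both 512 here).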