
Commit f2213e8

fix ofs_embed
1 parent 8966cb0 commit f2213e8

File tree

2 files changed: +13 additions, -3 deletions


src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 8 additions & 3 deletions
@@ -278,9 +278,10 @@ def __init__(
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
         self.ofs_embedding = None
-
         if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
             self.ofs_embedding = TimestepEmbedding(
                 ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
             )  # same as time embeddings, for ofs
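
The new ofs_proj / ofs_embedding pair mirrors time_proj / time_embedding: a sinusoidal Timesteps projection followed by a learned TimestepEmbedding MLP. Below is a minimal standalone sketch of what the pair computes; the concrete values (ofs_embed_dim=512, flip_sin_to_cos=True, downscale_freq_shift=0, act_fn="silu") are illustrative assumptions and not taken from this diff.

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    ofs_embed_dim = 512  # illustrative; the real value comes from the transformer config
    ofs_proj = Timesteps(num_channels=ofs_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
    ofs_embedding = TimestepEmbedding(in_channels=ofs_embed_dim, time_embed_dim=ofs_embed_dim, act_fn="silu")

    ofs = torch.full((1,), 2.0)       # one scalar ofs value per batch (the pipeline below uses 2.0)
    ofs_emb = ofs_proj(ofs)           # sinusoidal features, shape (1, ofs_embed_dim)
    ofs_emb = ofs_embedding(ofs_emb)  # learned MLP projection, shape (1, ofs_embed_dim)
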
@@ -433,6 +434,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
@@ -463,9 +465,12 @@ def forward(
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
+
         if self.ofs_embedding is not None:
-            emb_ofs = self.ofs_embedding(emb, timestep_cond)
-            emb = emb + emb_ofs
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
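
This hunk is the actual fix: before this commit, forward() had no ofs argument at all, and the ofs branch re-embedded the already-computed time embedding emb (with timestep_cond as its condition). After the change, the ofs value is projected through its own ofs_proj and ofs_embedding and only then added to the time embedding. A condensed, comment-annotated reading of the corrected path (the comments are editorial, not part of the source):

    t_emb = t_emb.to(dtype=hidden_states.dtype)        # sinusoidal timestep features, cast to the model dtype
    emb = self.time_embedding(t_emb, timestep_cond)     # learned timestep embedding

    if self.ofs_embedding is not None:                  # only when ofs_embed_dim was configured
        ofs_emb = self.ofs_proj(ofs)                    # project the new `ofs` argument, not `emb`
        ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
        ofs_emb = self.ofs_embedding(ofs_emb)           # learned ofs embedding (no timestep_cond)
        emb = emb + ofs_emb                             # ofs conditioning is added onto the time embedding
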

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 5 additions & 0 deletions
@@ -769,13 +769,17 @@ def __call__(
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
         # 7. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
 
+        # 8. Create ofs embeds if required
+        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

@@ -800,6 +804,7 @@ def __call__(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
+                    ofs=ofs_emb,
                     image_rotary_emb=image_rotary_emb,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
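
On the pipeline side, ofs_emb is a constant tensor of value 2.0 created only when the transformer was configured with an ofs_embed_dim, and it is now threaded through to the transformer's new ofs argument. A condensed sketch of the wiring after this commit; the surrounding noise_pred assignment and the trailing [0] indexing are assumed from the pipeline's existing denoising loop, which is not part of this diff.

    # only create an ofs value when the transformer actually has an ofs embedding path
    ofs_emb = (
        None
        if self.transformer.config.ofs_embed_dim is None
        else latents.new_full((1,), fill_value=2.0)  # 1-element tensor on the same device/dtype as latents
    )

    # inside the denoising loop, the constant is passed straight to the transformer
    noise_pred = self.transformer(
        hidden_states=latent_model_input,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        ofs=ofs_emb,                     # new argument added in this commit
        image_rotary_emb=image_rotary_emb,
        attention_kwargs=attention_kwargs,
        return_dict=False,
    )[0]
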
