Skip to content

Commit 0583a8d

Browse files
authored
Make CogVideoX RoPE implementation consistent (huggingface#9963)
* update cogvideox rope implementation
* apply suggestions from review
1 parent 7d0b9c4 commit 0583a8d

File tree

4 files changed

+78
-40
lines changed

4 files changed

+78
-40
lines changed

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -444,21 +444,34 @@ def _prepare_rotary_positional_embeddings(
444444
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
445445

446446
p = self.transformer.config.patch_size
447-
p_t = self.transformer.config.patch_size_t or 1
447+
p_t = self.transformer.config.patch_size_t
448448

449449
base_size_width = self.transformer.config.sample_width // p
450450
base_size_height = self.transformer.config.sample_height // p
451-
base_num_frames = (num_frames + p_t - 1) // p_t
452451

453-
grid_crops_coords = get_resize_crop_region_for_grid(
454-
(grid_height, grid_width), base_size_width, base_size_height
455-
)
456-
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
457-
embed_dim=self.transformer.config.attention_head_dim,
458-
crops_coords=grid_crops_coords,
459-
grid_size=(grid_height, grid_width),
460-
temporal_size=base_num_frames,
461-
)
452+
if p_t is None:
453+
# CogVideoX 1.0
454+
grid_crops_coords = get_resize_crop_region_for_grid(
455+
(grid_height, grid_width), base_size_width, base_size_height
456+
)
457+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
458+
embed_dim=self.transformer.config.attention_head_dim,
459+
crops_coords=grid_crops_coords,
460+
grid_size=(grid_height, grid_width),
461+
temporal_size=num_frames,
462+
)
463+
else:
464+
# CogVideoX 1.5
465+
base_num_frames = (num_frames + p_t - 1) // p_t
466+
467+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
468+
embed_dim=self.transformer.config.attention_head_dim,
469+
crops_coords=None,
470+
grid_size=(grid_height, grid_width),
471+
temporal_size=base_num_frames,
472+
grid_type="slice",
473+
max_size=(base_size_height, base_size_width),
474+
)
462475

463476
freqs_cos = freqs_cos.to(device=device)
464477
freqs_sin = freqs_sin.to(device=device)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -490,21 +490,34 @@ def _prepare_rotary_positional_embeddings(
490490
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
491491

492492
p = self.transformer.config.patch_size
493-
p_t = self.transformer.config.patch_size_t or 1
493+
p_t = self.transformer.config.patch_size_t
494494

495495
base_size_width = self.transformer.config.sample_width // p
496496
base_size_height = self.transformer.config.sample_height // p
497-
base_num_frames = (num_frames + p_t - 1) // p_t
498497

499-
grid_crops_coords = get_resize_crop_region_for_grid(
500-
(grid_height, grid_width), base_size_width, base_size_height
501-
)
502-
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
503-
embed_dim=self.transformer.config.attention_head_dim,
504-
crops_coords=grid_crops_coords,
505-
grid_size=(grid_height, grid_width),
506-
temporal_size=base_num_frames,
507-
)
498+
if p_t is None:
499+
# CogVideoX 1.0
500+
grid_crops_coords = get_resize_crop_region_for_grid(
501+
(grid_height, grid_width), base_size_width, base_size_height
502+
)
503+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
504+
embed_dim=self.transformer.config.attention_head_dim,
505+
crops_coords=grid_crops_coords,
506+
grid_size=(grid_height, grid_width),
507+
temporal_size=num_frames,
508+
)
509+
else:
510+
# CogVideoX 1.5
511+
base_num_frames = (num_frames + p_t - 1) // p_t
512+
513+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
514+
embed_dim=self.transformer.config.attention_head_dim,
515+
crops_coords=None,
516+
grid_size=(grid_height, grid_width),
517+
temporal_size=base_num_frames,
518+
grid_type="slice",
519+
max_size=(base_size_height, base_size_width),
520+
)
508521

509522
freqs_cos = freqs_cos.to(device=device)
510523
freqs_sin = freqs_sin.to(device=device)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,7 @@ def unfuse_qkv_projections(self) -> None:
528528
self.transformer.unfuse_qkv_projections()
529529
self.fusing_transformer = False
530530

531+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
531532
def _prepare_rotary_positional_embeddings(
532533
self,
533534
height: int,
@@ -541,11 +542,11 @@ def _prepare_rotary_positional_embeddings(
541542
p = self.transformer.config.patch_size
542543
p_t = self.transformer.config.patch_size_t
543544

544-
if p_t is None:
545-
# CogVideoX 1.0 I2V
546-
base_size_width = self.transformer.config.sample_width // p
547-
base_size_height = self.transformer.config.sample_height // p
545+
base_size_width = self.transformer.config.sample_width // p
546+
base_size_height = self.transformer.config.sample_height // p
548547

548+
if p_t is None:
549+
# CogVideoX 1.0
549550
grid_crops_coords = get_resize_crop_region_for_grid(
550551
(grid_height, grid_width), base_size_width, base_size_height
551552
)
@@ -556,9 +557,7 @@ def _prepare_rotary_positional_embeddings(
556557
temporal_size=num_frames,
557558
)
558559
else:
559-
# CogVideoX 1.5 I2V
560-
base_size_width = self.transformer.config.sample_width // p
561-
base_size_height = self.transformer.config.sample_height // p
560+
# CogVideoX 1.5
562561
base_num_frames = (num_frames + p_t - 1) // p_t
563562

564563
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -520,21 +520,34 @@ def _prepare_rotary_positional_embeddings(
520520
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
521521

522522
p = self.transformer.config.patch_size
523-
p_t = self.transformer.config.patch_size_t or 1
523+
p_t = self.transformer.config.patch_size_t
524524

525525
base_size_width = self.transformer.config.sample_width // p
526526
base_size_height = self.transformer.config.sample_height // p
527-
base_num_frames = (num_frames + p_t - 1) // p_t
528527

529-
grid_crops_coords = get_resize_crop_region_for_grid(
530-
(grid_height, grid_width), base_size_width, base_size_height
531-
)
532-
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
533-
embed_dim=self.transformer.config.attention_head_dim,
534-
crops_coords=grid_crops_coords,
535-
grid_size=(grid_height, grid_width),
536-
temporal_size=base_num_frames,
537-
)
528+
if p_t is None:
529+
# CogVideoX 1.0
530+
grid_crops_coords = get_resize_crop_region_for_grid(
531+
(grid_height, grid_width), base_size_width, base_size_height
532+
)
533+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
534+
embed_dim=self.transformer.config.attention_head_dim,
535+
crops_coords=grid_crops_coords,
536+
grid_size=(grid_height, grid_width),
537+
temporal_size=num_frames,
538+
)
539+
else:
540+
# CogVideoX 1.5
541+
base_num_frames = (num_frames + p_t - 1) // p_t
542+
543+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
544+
embed_dim=self.transformer.config.attention_head_dim,
545+
crops_coords=None,
546+
grid_size=(grid_height, grid_width),
547+
temporal_size=base_num_frames,
548+
grid_type="slice",
549+
max_size=(base_size_height, base_size_width),
550+
)
538551

539552
freqs_cos = freqs_cos.to(device=device)
540553
freqs_sin = freqs_sin.to(device=device)

0 commit comments

Comments
 (0)