Skip to content

Commit d833f72

Browse files
committed
make fix-copies
1 parent ea56788 commit d833f72

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,13 @@ def _prepare_rotary_positional_embeddings(
488488
) -> Tuple[torch.Tensor, torch.Tensor]:
489489
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
490490
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
491-
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
492-
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
491+
492+
p = self.transformer.config.patch_size
493+
p_t = self.transformer.config.patch_size_t or 1
494+
495+
base_size_width = self.transformer.config.sample_width // p
496+
base_size_height = self.transformer.config.sample_height // p
497+
base_num_frames = (num_frames + p_t - 1) // p_t
493498

494499
grid_crops_coords = get_resize_crop_region_for_grid(
495500
(grid_height, grid_width), base_size_width, base_size_height
@@ -498,7 +503,7 @@ def _prepare_rotary_positional_embeddings(
498503
embed_dim=self.transformer.config.attention_head_dim,
499504
crops_coords=grid_crops_coords,
500505
grid_size=(grid_height, grid_width),
501-
temporal_size=num_frames,
506+
temporal_size=base_num_frames,
502507
)
503508

504509
freqs_cos = freqs_cos.to(device=device)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,13 @@ def _prepare_rotary_positional_embeddings(
522522
) -> Tuple[torch.Tensor, torch.Tensor]:
523523
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
524524
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
525-
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
526-
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
525+
526+
p = self.transformer.config.patch_size
527+
p_t = self.transformer.config.patch_size_t or 1
528+
529+
base_size_width = self.transformer.config.sample_width // p
530+
base_size_height = self.transformer.config.sample_height // p
531+
base_num_frames = (num_frames + p_t - 1) // p_t
527532

528533
grid_crops_coords = get_resize_crop_region_for_grid(
529534
(grid_height, grid_width), base_size_width, base_size_height
@@ -532,7 +537,7 @@ def _prepare_rotary_positional_embeddings(
532537
embed_dim=self.transformer.config.attention_head_dim,
533538
crops_coords=grid_crops_coords,
534539
grid_size=(grid_height, grid_width),
535-
temporal_size=num_frames,
540+
temporal_size=base_num_frames,
536541
)
537542

538543
freqs_cos = freqs_cos.to(device=device)

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,8 +518,13 @@ def _prepare_rotary_positional_embeddings(
518518
) -> Tuple[torch.Tensor, torch.Tensor]:
519519
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
520520
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
521-
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
522-
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
521+
522+
p = self.transformer.config.patch_size
523+
p_t = self.transformer.config.patch_size_t or 1
524+
525+
base_size_width = self.transformer.config.sample_width // p
526+
base_size_height = self.transformer.config.sample_height // p
527+
base_num_frames = (num_frames + p_t - 1) // p_t
523528

524529
grid_crops_coords = get_resize_crop_region_for_grid(
525530
(grid_height, grid_width), base_size_width, base_size_height
@@ -528,7 +533,7 @@ def _prepare_rotary_positional_embeddings(
528533
embed_dim=self.transformer.config.attention_head_dim,
529534
crops_coords=grid_crops_coords,
530535
grid_size=(grid_height, grid_width),
531-
temporal_size=num_frames,
536+
temporal_size=base_num_frames,
532537
)
533538

534539
freqs_cos = freqs_cos.to(device=device)

0 commit comments

Comments
 (0)