Skip to content

Commit 25a9e1c

Browse files
committed
fix
1 parent 3dba37f commit 25a9e1c

File tree

2 files changed

+60
-19
lines changed

2 files changed

+60
-19
lines changed

src/diffusers/models/embeddings.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
517517

518518

519519
def get_3d_rotary_pos_embed(
520-
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
520+
embed_dim,
521+
crops_coords,
522+
grid_size,
523+
temporal_size,
524+
theta: int = 10000,
525+
use_real: bool = True,
526+
grid_type: str = "linspace",
527+
max_size: Optional[Tuple[int, int]] = None,
521528
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
522529
"""
523530
RoPE for video tokens with 3D structure.
@@ -533,17 +540,30 @@ def get_3d_rotary_pos_embed(
533540
The size of the temporal dimension.
534541
theta (`float`):
535542
Scaling factor for frequency computation.
543+
grid_type (`str`):
544+
Whether to use "linspace" or "slice" to compute grids.
536545
537546
Returns:
538547
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
539548
"""
540549
if use_real is not True:
541550
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
542-
start, stop = crops_coords
543-
grid_size_h, grid_size_w = grid_size
544-
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
545-
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
546-
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
551+
552+
if grid_type == "linspace":
553+
start, stop = crops_coords
554+
grid_size_h, grid_size_w = grid_size
555+
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
556+
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
557+
grid_t = np.arange(temporal_size, dtype=np.float32)
558+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
559+
elif grid_type == "slice":
560+
max_h, max_w = max_size
561+
grid_size_h, grid_size_w = grid_size
562+
grid_h = np.arange(max_h, dtype=np.float32)
563+
grid_w = np.arange(max_w, dtype=np.float32)
564+
grid_t = np.arange(temporal_size, dtype=np.float32)
565+
else:
566+
raise ValueError("Invalid value passed for `grid_type`.")
547567

548568
# Compute dimensions for each axis
549569
dim_t = embed_dim // 4
@@ -579,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
579599
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
580600
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
581601
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
602+
603+
if grid_type == "slice":
604+
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
605+
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
606+
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
607+
582608
cos = combine_time_height_width(t_cos, h_cos, w_cos)
583609
sin = combine_time_height_width(t_sin, h_sin, w_sin)
584610
return cos, sin

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -540,21 +540,36 @@ def _prepare_rotary_positional_embeddings(
540540
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
541541

542542
p = self.transformer.config.patch_size
543-
p_t = self.transformer.config.patch_size_t or 1
543+
p_t = self.transformer.config.patch_size_t
544544

545-
base_size_width = self.transformer.config.sample_width // p
546-
base_size_height = self.transformer.config.sample_height // p
547-
base_num_frames = (num_frames + p_t - 1) // p_t
545+
if p_t is None:
546+
# CogVideoX 1.0 I2V
547+
base_size_width = self.transformer.config.sample_width // p
548+
base_size_height = self.transformer.config.sample_height // p
548549

549-
grid_crops_coords = get_resize_crop_region_for_grid(
550-
(grid_height, grid_width), base_size_width, base_size_height
551-
)
552-
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
553-
embed_dim=self.transformer.config.attention_head_dim,
554-
crops_coords=grid_crops_coords,
555-
grid_size=(grid_height, grid_width),
556-
temporal_size=base_num_frames,
557-
)
550+
grid_crops_coords = get_resize_crop_region_for_grid(
551+
(grid_height, grid_width), base_size_width, base_size_height
552+
)
553+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
554+
embed_dim=self.transformer.config.attention_head_dim,
555+
crops_coords=grid_crops_coords,
556+
grid_size=(grid_height, grid_width),
557+
temporal_size=num_frames,
558+
)
559+
else:
560+
# CogVideoX 1.5 I2V
561+
base_size_width = self.transformer.config.sample_width // p
562+
base_size_height = self.transformer.config.sample_height // p
563+
base_num_frames = (num_frames + p_t - 1) // p_t
564+
565+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
566+
embed_dim=self.transformer.config.attention_head_dim,
567+
crops_coords=None,
568+
grid_size=(grid_height, grid_width),
569+
temporal_size=base_num_frames,
570+
grid_type="slice",
571+
max_size=(base_size_height, base_size_width),
572+
)
558573

559574
freqs_cos = freqs_cos.to(device=device)
560575
freqs_sin = freqs_sin.to(device=device)

0 commit comments

Comments
 (0)