Commit f4457fb

feat(embeddings): add CogView 2D rotary positional embedding
1 parent e9f6626 commit f4457fb

File tree

1 file changed: +75 -0 lines changed

src/diffusers/models/embeddings.py

Lines changed: 75 additions & 0 deletions
@@ -2611,3 +2611,78 @@ def forward(self, image_embeds: List[torch.Tensor]):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
+
+
+class CogViewRotary2DEmbedding(nn.Module):
+    def __init__(
+        self,
+        kv_channels: int,
+        rotary_percent: float,
+        max_h: int = 128,
+        max_w: int = 128,
+        rotary_interleaved: bool = False,
+        seq_len_interpolation_factor: float = None,
+        inner_interp: bool = False,
+        rotary_base: int = 10000,
+    ) -> None:
+        super().__init__()
+
+        dim = kv_channels
+        if rotary_percent < 1.0:
+            dim = int(dim * rotary_percent)
+        self.rotary_interleaved = rotary_interleaved
+
+        self.seq_len_interpolation_factor = seq_len_interpolation_factor
+        self.inner_interp = inner_interp
+
+        dim_h = kv_channels // 2
+        dim_w = kv_channels // 2
+
+        device = torch.cuda.current_device()
+        h_inv_freq = 1.0 / (
+            rotary_base
+            ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            rotary_base
+            ** (torch.arange(0, dim_w, 2, dtype=torch.float32, device=device)[: (dim_w // 2)].float() / dim_w)
+        )
+
+        h_seq = torch.arange(max_h, device=device, dtype=h_inv_freq.dtype)
+        w_seq = torch.arange(max_w, device=device, dtype=w_inv_freq.dtype)
+
+        self.freqs_h = torch.outer(h_seq, h_inv_freq)
+        self.freqs_w = torch.outer(w_seq, w_inv_freq)
+        self.max_h = max_h
+        self.max_w = max_w
+
+    def forward(
+        self,
+        h_idx: torch.Tensor,
+        w_idx: torch.Tensor,
+        target_h: torch.Tensor = None,
+        target_w: torch.Tensor = None,
+        mask: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if self.inner_interp:
+            inner_h_idx = (h_idx * self.max_h) // target_h
+            inner_w_idx = (w_idx * self.max_w) // target_w
+
+            h_emb = self.freqs_h[inner_h_idx]
+            w_emb = self.freqs_w[inner_w_idx]
+
+        else:
+            h_emb = self.freqs_h[h_idx]
+            w_emb = self.freqs_w[w_idx]
+
+        mask = (mask == 1).unsqueeze(-1)
+
+        emb = torch.cat([h_emb, w_emb], dim=-1) * mask
+
+        assert emb.ndim == 2, f"expected emb to have 2 dimensions, got {emb.ndim}"
+        if not self.rotary_interleaved:
+            emb = torch.repeat_interleave(emb, 2, dim=0)
+        else:
+            emb = torch.repeat_interleave(emb, 2, dim=1)
+
+        return emb
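
A minimal usage sketch of the new module (not part of the commit). It assumes a CUDA device is available, since the constructor builds its frequency tables via torch.cuda.current_device(); the 4x4 grid, kv_channels=64, and all tensor values below are illustrative assumptions:

import torch

from diffusers.models.embeddings import CogViewRotary2DEmbedding

device = torch.device("cuda")  # __init__ places freqs_h/freqs_w on the current CUDA device

# Illustrative 4x4 patch grid (16 tokens) with 64 key/value channels;
# height and width each get half of kv_channels, so a token's angle vector
# encodes its row and column independently.
rope = CogViewRotary2DEmbedding(kv_channels=64, rotary_percent=1.0)

h_idx = torch.arange(4, device=device).repeat_interleave(4)  # row of each token: 0,0,0,0,1,1,...
w_idx = torch.arange(4, device=device).repeat(4)             # column of each token: 0,1,2,3,0,1,...
mask = torch.ones(16, device=device)                         # all 16 tokens are real (no padding)

emb = rope(h_idx, w_idx, mask=mask)
# Height and width angles are concatenated per token ([16, 32] here); the
# non-interleaved branch then repeats along dim 0, yielding a [32, 32] tensor.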

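The inner_interp path remaps positions from an arbitrary target grid onto the fixed max_h x max_w tables with integer arithmetic. A small standalone sketch of that remapping (the 256-row target grid is an assumed size, chosen for illustration):

import torch

# Same arithmetic as the inner_interp branch of forward(): positions on a
# 256-row target grid are squeezed into the 128-entry height table, so two
# adjacent positions share each table row.
max_h, target_h = 128, 256
h_idx = torch.arange(target_h)
inner_h_idx = (h_idx * max_h) // target_h
print(inner_h_idx[:6])  # tensor([0, 0, 1, 1, 2, 2])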