@@ -2611,3 +2611,78 @@ def forward(self, image_embeds: List[torch.Tensor]):
26112611 projected_image_embeds .append (image_embed )
26122612
26132613 return projected_image_embeds
2614+
2615+
2616+ class CogViewRotary2DEmbedding (nn .Module ):
2617+ def __init__ (
2618+ self ,
2619+ kv_channels : int ,
2620+ rotary_percent : float ,
2621+ max_h : int = 128 ,
2622+ max_w : int = 128 ,
2623+ rotary_interleaved : bool = False ,
2624+ seq_len_interpolation_factor : float = None ,
2625+ inner_interp : bool = False ,
2626+ rotary_base : int = 10000 ,
2627+ ) -> None :
2628+ super ().__init__ ()
2629+
2630+ dim = kv_channels
2631+ if rotary_percent < 1.0 :
2632+ dim = int (dim * rotary_percent )
2633+ self .rotary_interleaved = rotary_interleaved
2634+
2635+ self .seq_len_interpolation_factor = seq_len_interpolation_factor
2636+ self .inner_interp = inner_interp
2637+
2638+ dim_h = kv_channels // 2
2639+ dim_w = kv_channels // 2
2640+
2641+ device = torch .cuda .current_device ()
2642+ h_inv_freq = 1.0 / (
2643+ rotary_base
2644+ ** (torch .arange (0 , dim_h , 2 , dtype = torch .float32 , device = device )[: (dim_h // 2 )].float () / dim_h )
2645+ )
2646+ w_inv_freq = 1.0 / (
2647+ rotary_base
2648+ ** (torch .arange (0 , dim_w , 2 , dtype = torch .float32 , device = device )[: (dim_w // 2 )].float () / dim_w )
2649+ )
2650+
2651+ h_seq = torch .arange (max_h , device = device , dtype = h_inv_freq .dtype )
2652+ w_seq = torch .arange (max_w , device = device , dtype = w_inv_freq .dtype )
2653+
2654+ self .freqs_h = torch .outer (h_seq , h_inv_freq )
2655+ self .freqs_w = torch .outer (w_seq , w_inv_freq )
2656+ self .max_h = max_h
2657+ self .max_w = max_w
2658+
2659+ def forward (
2660+ self ,
2661+ h_idx : torch .Tensor ,
2662+ w_idx : torch .Tensor ,
2663+ target_h : torch .Tensor = None ,
2664+ target_w : torch .Tensor = None ,
2665+ mask : torch .Tensor = None ,
2666+ ) -> torch .Tensor :
2667+ if self .inner_interp :
2668+ inner_h_idx = (h_idx * self .max_h ) // target_h
2669+ inner_w_idx = (w_idx * self .max_w ) // target_w
2670+
2671+ h_emb = self .freqs_h [inner_h_idx ]
2672+ w_emb = self .freqs_w [inner_w_idx ]
2673+
2674+ else :
2675+ h_emb = self .freqs_h [h_idx ]
2676+ w_emb = self .freqs_w [w_idx ]
2677+
2678+ mask = (mask == 1 ).unsqueeze (- 1 )
2679+
2680+ emb = torch .cat ([h_emb , w_emb ], dim = - 1 ) * mask
2681+
2682+ assert emb .ndim == 2 , f"expected emb to have 2 dimensions, got { emb .ndim } "
2683+ if not self .rotary_interleaved :
2684+ emb = torch .repeat_interleave (emb , 2 , dim = 0 )
2685+ else :
2686+ emb = torch .repeat_interleave (emb , 2 , dim = 1 )
2687+
2688+ return emb
0 commit comments