@@ -517,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
517517
518518
519519def get_3d_rotary_pos_embed (
520- embed_dim , crops_coords , grid_size , temporal_size , theta : int = 10000 , use_real : bool = True
520+ embed_dim ,
521+ crops_coords ,
522+ grid_size ,
523+ temporal_size ,
524+ theta : int = 10000 ,
525+ use_real : bool = True ,
526+ grid_type : str = "linspace" ,
527+ max_size : Optional [Tuple [int , int ]] = None ,
521528) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
522529 """
523530 RoPE for video tokens with 3D structure.
@@ -533,17 +540,30 @@ def get_3d_rotary_pos_embed(
533540 The size of the temporal dimension.
534541 theta (`float`):
535542 Scaling factor for frequency computation.
543+ grid_type (`str`):
544+ Whether to use "linspace" or "slice" to compute grids.
536545
537546 Returns:
538547 `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
539548 """
540549 if use_real is not True :
541550 raise ValueError (" `use_real = False` is not currently supported for get_3d_rotary_pos_embed" )
542- start , stop = crops_coords
543- grid_size_h , grid_size_w = grid_size
544- grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
545- grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
546- grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
551+
552+ if grid_type == "linspace" :
553+ start , stop = crops_coords
554+ grid_size_h , grid_size_w = grid_size
555+ grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
556+ grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
557+ grid_t = np .arange (temporal_size , dtype = np .float32 )
558+ grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
559+ elif grid_type == "slice" :
560+ max_h , max_w = max_size
561+ grid_size_h , grid_size_w = grid_size
562+ grid_h = np .arange (max_h , dtype = np .float32 )
563+ grid_w = np .arange (max_w , dtype = np .float32 )
564+ grid_t = np .arange (temporal_size , dtype = np .float32 )
565+ else :
566+ raise ValueError ("Invalid value passed for `grid_type`." )
547567
548568 # Compute dimensions for each axis
549569 dim_t = embed_dim // 4
@@ -579,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
579599 t_cos , t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
580600 h_cos , h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
581601 w_cos , w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
602+
603+ if grid_type == "slice" :
604+ t_cos , t_sin = t_cos [:temporal_size ], t_sin [:temporal_size ]
605+ h_cos , h_sin = h_cos [:grid_size_h ], h_sin [:grid_size_h ]
606+ w_cos , w_sin = w_cos [:grid_size_w ], w_sin [:grid_size_w ]
607+
582608 cos = combine_time_height_width (t_cos , h_cos , w_cos )
583609 sin = combine_time_height_width (t_sin , h_sin , w_sin )
584610 return cos , sin
0 commit comments