@@ -192,23 +192,20 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
192192 return embeddings
193193
194194 def fast_pos_embed_interpolate (self , grid_thw ):
195- # breakpoint()
196- # gridbs, grid_ts, grid_hs, grid_ws = grid_thw.shape
197- bs , t , h , w = grid_thw .shape
198- # grid_ts = torch.tensor([grid_ts], device=grid_thw.device)
199- # grid_hs = torch.tensor([grid_hs], device=grid_thw.device)
200- # grid_ws = torch.tensor([grid_ws], device=grid_thw.device)
201- idx_list = [[] for _ in range (4 )]
202- weight_list = [[] for _ in range (4 )]
203- # t,h,w = grid_ts[0],grid_hs[0],grid_ws[0]
204- # for t, h, w in zip(grid_ts, grid_hs, grid_ws):
195+ bs ,t ,h ,w = grid_thw .shape
196+
205197 h_idxs = torch .linspace (0 , self .num_grid_per_side - 1 , h )
206198 w_idxs = torch .linspace (0 , self .num_grid_per_side - 1 , w )
207199
208200 h_idxs_floor = h_idxs .int ()
209201 w_idxs_floor = w_idxs .int ()
210- h_idxs_ceil = (h_idxs .int () + 1 ).clip (max = self .num_grid_per_side - 1 )
211- w_idxs_ceil = (w_idxs .int () + 1 ).clip (max = self .num_grid_per_side - 1 )
202+ # h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
203+ # w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
204+ # To resolve the clip issue
205+ max_t = torch .tensor (self .num_grid_per_side - 1 , device = h_idxs .device )
206+
207+ h_idxs_ceil = torch .minimum (h_idxs_floor + 1 , max_t ) # working
208+ w_idxs_ceil = torch .minimum (w_idxs_floor + 1 , max_t )
212209
213210 dh = h_idxs - h_idxs_floor
214211 dw = w_idxs - w_idxs_floor
@@ -230,12 +227,12 @@ def fast_pos_embed_interpolate(self, grid_thw):
230227 (dh [None ].T * dw [None ]).flatten (),
231228 ]
232229
233- for i in range ( 4 ):
234- idx_list [ i ]. extend ( indices [ i ]. tolist ())
235- weight_list [ i ]. extend ( weights [ i ]. tolist ())
236- idx_tensor = torch . tensor ( idx_list , dtype = torch . long , device = self . pos_embed . weight . device )
237- weight_tensor = torch .tensor (
238- weight_list , dtype = self .pos_embed .weight .dtype , device = self .pos_embed .weight .device
230+ idx_tensor = torch . stack ( indices , dim = 0 ). to (
231+ dtype = torch . long , device = self . pos_embed . weight . device
232+ ) # [4, h*w]
233+
234+ weight_tensor = torch .stack ( weights , dim = 0 ). to (
235+ dtype = self .pos_embed .weight .dtype , device = self .pos_embed .weight .device
239236 )
240237 pos_embeds = self .pos_embed (idx_tensor ) * weight_tensor [:, :, None ]
241238 patch_pos_embeds = pos_embeds [0 ] + pos_embeds [1 ] + pos_embeds [2 ] + pos_embeds [3 ]
@@ -244,9 +241,7 @@ def fast_pos_embed_interpolate(self, grid_thw):
244241
245242 patch_pos_embeds_permute = []
246243 merge_size = self .config .spatial_merge_size
247- # breakpoint()
248244 pos_embed = patch_pos_embeds [0 ]
249- # for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
250245 pos_embed = pos_embed .repeat (t , 1 )
251246
252247 pos_embed = (
0 commit comments