@@ -405,25 +405,39 @@ def fast_pos_embed_interpolate(self,
405
405
dh = h_idxs - h_floor
406
406
dw = w_idxs - w_floor
407
407
408
- w00 = ((1 - dh )[:, None ] * (1 - dw )[None , :]).reshape (- 1 )
409
- w01 = ((1 - dh )[:, None ] * dw [None , :]).reshape (- 1 )
410
- w10 = (dh [:, None ] * (1 - dw )[None , :]).reshape (- 1 )
411
- w11 = (dh [:, None ] * dw [None , :]).reshape (- 1 )
412
-
413
- idx00 = (h_floor [:, None ] * num_grid_per_side +
414
- w_floor [None , :]).reshape (- 1 )
415
- idx01 = (h_floor [:, None ] * num_grid_per_side +
416
- w_ceil [None , :]).reshape (- 1 )
417
- idx10 = (h_ceil [:, None ] * num_grid_per_side +
418
- w_floor [None , :]).reshape (- 1 )
419
- idx11 = (h_ceil [:, None ] * num_grid_per_side +
420
- w_ceil [None , :]).reshape (- 1 )
421
-
422
- indices = torch .stack ([idx00 , idx01 , idx10 , idx11 ], dim = 0 )
408
+ # Create meshgrid view for all h, w vars
409
+ dh_grid , dw_grid = torch .meshgrid (dh , dw , indexing = 'ij' )
410
+ h_floor_grid , w_floor_grid = torch .meshgrid (h_floor ,
411
+ w_floor ,
412
+ indexing = 'ij' )
413
+ h_ceil_grid , w_ceil_grid = torch .meshgrid (h_ceil ,
414
+ w_ceil ,
415
+ indexing = 'ij' )
416
+ h_floor_grid_idx = h_floor_grid * num_grid_per_side
417
+ h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
418
+
419
+ # direct (reference) form of the four interpolation weights:
420
+ # w00 = (1 - dh_grid) * (1 - dw_grid)
421
+ # w01 = (1 - dh_grid) * dw_grid
422
+ # w10 = dh_grid * (1 - dw_grid)
423
+ # w11 = dh_grid * dw_grid
424
+ # compute dh_grid * dw_grid only once (as w11) and derive the
425
+ # other three weights from it by expanding the products above
426
+ w11 = dh_grid * dw_grid
427
+ w10 = dh_grid - w11
428
+ w01 = dw_grid - w11
429
+ w00 = 1 - dh_grid - dw_grid + w11
430
+
431
+ idx00 = h_floor_grid_idx + w_floor_grid
432
+ idx01 = h_floor_grid_idx + w_ceil_grid
433
+ idx10 = h_ceil_grid_idx + w_floor_grid
434
+ idx11 = h_ceil_grid_idx + w_ceil_grid
435
+
436
+ indices = torch .stack ([idx00 , idx01 , idx10 , idx11 ],
437
+ dim = 0 ).reshape (4 , - 1 )
423
438
weights = torch .stack ([w00 , w01 , w10 , w11 ],
424
- dim = 0 ).to (dtype = self .dtype ,
425
- device = self .device )
426
- weights = weights .unsqueeze (- 1 )
439
+ dim = 0 ).reshape (4 , - 1 , 1 )
440
+ weights = weights .to (dtype = self .dtype , device = self .device )
427
441
428
442
embeds = self .pos_embed (indices )
429
443
weighted_embeds = embeds * weights
0 commit comments