@@ -617,7 +617,7 @@ def dtype(self) -> torch.dtype:
617
617
def device(self) -> torch.device:
    """Return the device that holds the patch-embedding projection weight.

    The value is read from ``self.patch_embed.patchifier.proj.weight`` and
    is used as the reference device for this module.
    """
    projection = self.patch_embed.patchifier.proj
    return projection.weight.device
619
619
620
- def get_pos_ids_by_grid (self , grid_thw ) :
620
+ def get_pos_ids_by_grid (self , grid_thw : list [ list [ int ]]) -> list [ torch . Tensor ] :
621
621
pos_ids = []
622
622
for t , h , w in grid_thw :
623
623
hpos_ids = torch .arange (h ).unsqueeze (1 ).expand (- 1 , w )
@@ -643,10 +643,10 @@ def get_pos_ids_by_grid(self, grid_thw):
643
643
644
644
return pos_ids
645
645
646
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
    """Gather rotary position embeddings for every position id of the grids.

    Args:
        grid_thw: One ``[t, h, w]`` triple per grid, given as plain Python
            integers (not a tensor), so the largest side can be computed
            without touching any tensor.

    Returns:
        The rows of the rotary table selected by the concatenated position
        ids, flattened from dimension 1 onward.
    """
    # Per-grid position ids, joined into one index tensor along dim 0.
    per_grid_ids = self.get_pos_ids_by_grid(grid_thw)
    all_ids = torch.cat(per_grid_ids, dim=0)

    # The rotary table only has to cover the largest height/width seen.
    largest_side = max(side for _, h, w in grid_thw for side in (h, w))
    rotary_table = self.rotary_pos_emb(largest_side)

    # Look up one table row per id, then merge the trailing dimensions.
    return rotary_table[all_ids].flatten(1)
@@ -667,13 +667,13 @@ def compute_attn_mask_seqlen(
667
667
def forward (
668
668
self , hidden_states : torch .Tensor , grid_thw : list [list [int ]]
669
669
) -> torch .Tensor :
670
+ rotary_pos_emb = self .rot_pos_emb (grid_thw )
671
+
670
672
# Convert grid_thw to tensor (always expecting list format now)
671
673
grid_thw = torch .tensor (grid_thw , device = hidden_states .device , dtype = torch .long )
672
674
hidden_states = hidden_states .to (self .dtype )
673
675
hidden_states = self .patch_embed (hidden_states , grid_thw )
674
676
675
- rotary_pos_emb = self .rot_pos_emb (grid_thw )
676
-
677
677
cu_seqlens = torch .repeat_interleave (
678
678
grid_thw [:, 1 ] * grid_thw [:, 2 ], grid_thw [:, 0 ]
679
679
).cumsum (
@@ -807,7 +807,7 @@ def _process_image_input(
807
807
rope_type = "rope_3d" ,
808
808
)
809
809
else :
810
- image_embeds = self .vision_tower (pixel_values , grid_thw )[
810
+ image_embeds = self .vision_tower (pixel_values , grid_thw_list )[
811
811
:, : self .config .hidden_size
812
812
]
813
813
0 commit comments