@@ -2090,11 +2090,44 @@ def prepare_inputs_for_generation(
         return model_inputs
 
 class patched_Qwen2_5_VisionTransformerPretrainedModel:
-    _PATCHES_ = ["get_window_index", "forward"]
+    _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
     _PATCHED_CLASS_ = (
         transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
     )
 
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        for thw in grid_thw:
+            # PATCHED: avoid unbind
+            t = thw[0]
+            h = thw[1]
+            w = thw[2]
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
     def get_window_index(self, grid_thw):
         window_index: list = []
         # PATCHED
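A quick way to sanity-check the `rot_pos_emb` patch above is to confirm that plain indexing reads the same (t, h, w) values per row as the upstream tuple unpacking, which goes through `Tensor.unbind`. The snippet below is a standalone sketch with a made-up `grid_thw`; it is not part of the patch.

# Standalone sketch (toy grid_thw, not part of the patch): iterating with
# plain indexing yields the same (t, h, w) triplets as tuple unpacking,
# which is what the upstream rot_pos_emb does via unbind.
import torch

grid_thw = torch.tensor([[1, 4, 6], [2, 8, 4]])  # hypothetical (t, h, w) rows

unpacked = [(int(t), int(h), int(w)) for t, h, w in grid_thw]    # upstream style
indexed = [(int(r[0]), int(r[1]), int(r[2])) for r in grid_thw]  # patched style
assert unpacked == indexed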
@@ -2104,7 +2137,11 @@ def get_window_index(self, grid_thw):
             self.window_size // self.spatial_merge_size // self.patch_size
         )
 
-        for grid_t, grid_h, grid_w in grid_thw:
+        for _thw in grid_thw:
+            # PATCHED: avoid unbind
+            grid_t = _thw[0]
+            grid_h = _thw[1]
+            grid_w = _thw[2]
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
                 grid_w // self.spatial_merge_size,
@@ -2279,12 +2316,13 @@ def forward(
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
-        query_states, key_states, value_states = (
+        # PATCHED: avoid the use of unbind
+        qkv = (
             self.qkv(hidden_states)
             .reshape(seq_length, 3, self.num_heads, -1)
             .permute(1, 0, 2, 3)
-            .unbind(0)
         )
+        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
         cos, sin = position_embeddings
         query_states, key_states = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
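The same replacement pattern appears in the attention `forward` above: `qkv[0], qkv[1], qkv[2]` takes the place of `.unbind(0)` on the permuted projection output. Below is a minimal standalone check with toy shapes (no real model weights; the dimensions are made up) showing the two forms return identical tensors.

# Standalone sketch (toy shapes): indexing the leading axis of the permuted
# qkv tensor returns the same three tensors as unbind(0).
import torch

seq_length, num_heads, head_dim = 5, 2, 4
qkv_out = torch.randn(seq_length, 3 * num_heads * head_dim)  # stand-in for self.qkv(hidden_states)

qkv = qkv_out.reshape(seq_length, 3, num_heads, -1).permute(1, 0, 2, 3)
q_ref, k_ref, v_ref = qkv.unbind(0)   # original style
q, k, v = qkv[0], qkv[1], qkv[2]      # patched style
assert all(torch.equal(a, b) for a, b in zip((q, k, v), (q_ref, k_ref, v_ref)))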