@@ -68,7 +68,7 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
6868 Returns:
6969 `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
7070 """
71-
71+ # breakpoint()
7272 cos = cos .unsqueeze (unsqueeze_dim )
7373 sin = sin .unsqueeze (unsqueeze_dim )
7474 q_embed = (q * cos ) + (rotate_half (q ) * sin )
@@ -144,8 +144,43 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
144144
145145
146146class QEffQwen3VLVisionModel (Qwen3VLVisionModel ):
147+ # def rot_pos_emb(self, grid_thw):
148+ # pos_ids = []
149+
150+ # bs, t, h, w = grid_thw.shape
151+
152+ # hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
153+ # hpos_ids = hpos_ids.reshape(
154+ # h // self.spatial_merge_size,
155+ # self.spatial_merge_size,
156+ # w // self.spatial_merge_size,
157+ # self.spatial_merge_size,
158+ # )
159+ # hpos_ids = hpos_ids.permute(0, 2, 1, 3)
160+ # hpos_ids = hpos_ids.flatten()
161+
162+ # wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
163+ # wpos_ids = wpos_ids.reshape(
164+ # h // self.spatial_merge_size,
165+ # self.spatial_merge_size,
166+ # w // self.spatial_merge_size,
167+ # self.spatial_merge_size,
168+ # )
169+ # wpos_ids = wpos_ids.permute(0, 2, 1, 3)
170+ # wpos_ids = wpos_ids.flatten()
171+ # pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
172+ # pos_ids = torch.cat(pos_ids, dim=0)
173+
174+ # x_expanded = pos_ids.unsqueeze(0)
175+ # x_expanded = x_expanded.expand(bs, -1, -1)
176+ # pos_ids = x_expanded.reshape(-1, pos_ids.size(1))
177+
178+ # max_grid_size = max(grid_thw.shape)
179+ # rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
180+ # rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
181+ # return rotary_pos_emb
147182 def rot_pos_emb (self , grid_thw : torch .Tensor ) -> torch .Tensor :
148- # ## breakpoint()
183+ # breakpoint()
149184 merge_size = self .spatial_merge_size
150185
151186 # max_hw = int(grid_thw[:, 1:].max().item())
@@ -160,7 +195,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
160195 total_tokens = int (torch .prod (grid_thw , dim = 1 ).sum ().item ())
161196 pos_ids = torch .empty ((total_tokens , 2 ), dtype = torch .long , device = device )
162197
163- offset = 0
198+ # offset = 0
164199 # breakpoint()
165200
166201 # for bs,num_frames, height, width in grid_thw:
@@ -182,33 +217,39 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
182217
183218 if num_frames > 1 :
184219 coords = coords .repeat (num_frames , 1 )
185-
186- num_tokens = coords .shape [0 ]
187- pos_ids [offset : offset + num_tokens ] = coords
188- offset += num_tokens
189-
220+ # breakpoint()
221+ # num_tokens = len(coords)
222+ # pos_ids[offset : offset + num_tokens] = coords
223+ # offset += num_tokens
224+ # breakpoint()
225+ # pos_ids = torch.cat([pos_ids, coords], dim=-1)
226+ pos_ids = coords
190227 embeddings = freq_table [pos_ids ] # lookup rotary embeddings
191228 embeddings = embeddings .flatten (1 )
192229 return embeddings
193230
194231 def fast_pos_embed_interpolate (self , grid_thw ):
195- # breakpoint()
196- # gridbs, grid_ts, grid_hs, grid_ws = grid_thw.shape
197232 bs , t , h , w = grid_thw .shape
198233 # grid_ts = torch.tensor([grid_ts], device=grid_thw.device)
199234 # grid_hs = torch.tensor([grid_hs], device=grid_thw.device)
200235 # grid_ws = torch.tensor([grid_ws], device=grid_thw.device)
201- idx_list = [[] for _ in range (4 )]
202- weight_list = [[] for _ in range (4 )]
236+ # idx_list = [[] for _ in range(4)]
237+ # weight_list = [[] for _ in range(4)]
203238 # t,h,w = grid_ts[0],grid_hs[0],grid_ws[0]
204239 # for t, h, w in zip(grid_ts, grid_hs, grid_ws):
205240 h_idxs = torch .linspace (0 , self .num_grid_per_side - 1 , h )
206241 w_idxs = torch .linspace (0 , self .num_grid_per_side - 1 , w )
207242
208243 h_idxs_floor = h_idxs .int ()
209244 w_idxs_floor = w_idxs .int ()
210- h_idxs_ceil = (h_idxs .int () + 1 ).clip (max = self .num_grid_per_side - 1 )
211- w_idxs_ceil = (w_idxs .int () + 1 ).clip (max = self .num_grid_per_side - 1 )
245+ # h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
246+ # w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
247+ # h_idxs_ceil = torch.clip(h_idxs.int(),0,self.num_grid_per_side - 1)
248+ # w_idxs_ceil = torch.clip(w_idxs.int(),0,self.num_grid_per_side - 1)
249+ max_t = torch .tensor (self .num_grid_per_side - 1 , device = h_idxs .device )
250+
251+ h_idxs_ceil = torch .minimum (h_idxs_floor + 1 , max_t ) # working
252+ w_idxs_ceil = torch .minimum (w_idxs_floor + 1 , max_t )
212253
213254 dh = h_idxs - h_idxs_floor
214255 dw = w_idxs - w_idxs_floor
@@ -230,12 +271,10 @@ def fast_pos_embed_interpolate(self, grid_thw):
230271 (dh [None ].T * dw [None ]).flatten (),
231272 ]
232273
233- for i in range (4 ):
234- idx_list [i ].extend (indices [i ].tolist ())
235- weight_list [i ].extend (weights [i ].tolist ())
236- idx_tensor = torch .tensor (idx_list , dtype = torch .long , device = self .pos_embed .weight .device )
237- weight_tensor = torch .tensor (
238- weight_list , dtype = self .pos_embed .weight .dtype , device = self .pos_embed .weight .device
274+ idx_tensor = torch .stack (indices , dim = 0 ).to (dtype = torch .long , device = self .pos_embed .weight .device ) # [4, h*w]
275+
276+ weight_tensor = torch .stack (weights , dim = 0 ).to (
277+ dtype = self .pos_embed .weight .dtype , device = self .pos_embed .weight .device
239278 )
240279 pos_embeds = self .pos_embed (idx_tensor ) * weight_tensor [:, :, None ]
241280 patch_pos_embeds = pos_embeds [0 ] + pos_embeds [1 ] + pos_embeds [2 ] + pos_embeds [3 ]
@@ -244,9 +283,7 @@ def fast_pos_embed_interpolate(self, grid_thw):
244283
245284 patch_pos_embeds_permute = []
246285 merge_size = self .config .spatial_merge_size
247- # breakpoint()
248286 pos_embed = patch_pos_embeds [0 ]
249- # for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
250287 pos_embed = pos_embed .repeat (t , 1 )
251288
252289 pos_embed = (
@@ -262,20 +299,18 @@ def fast_pos_embed_interpolate(self, grid_thw):
262299 return patch_pos_embeds
263300
264301 def forward (self , hidden_states : torch .Tensor , grid_thw : torch .Tensor ) -> torch .Tensor :
265- # ##breakpoint()
266302 hidden_states = self .patch_embed (hidden_states )
267303 pos_embeds = self .fast_pos_embed_interpolate (grid_thw )
268- # breakpoint()
304+
269305 hidden_states = hidden_states + pos_embeds
270- # breakpoint()
306+
271307 rotary_pos_emb = self .rot_pos_emb (grid_thw )
272308
273309 seq_len , _ = hidden_states .size ()
274310 hidden_states = hidden_states .reshape (seq_len , - 1 )
275311 rotary_pos_emb = rotary_pos_emb .reshape (seq_len , - 1 )
276312 emb = torch .cat ((rotary_pos_emb , rotary_pos_emb ), dim = - 1 )
277313 position_embeddings = (emb .cos (), emb .sin ())
278- # ##breakpoint()
279314 bs , t , h , w = grid_thw .shape
280315
281316 t = torch .arange (t , t + 1 ).squeeze ().expand (bs )
@@ -286,7 +321,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
286321 dim = 0 ,
287322 dtype = torch .int32 ,
288323 )
289- # ##breakpoint()
290324 cu_seqlens = torch .cat ([torch .tensor ([0 ], dtype = cu_seqlens .dtype ), cu_seqlens ])
291325
292326 deepstack_feature_lists = []
@@ -301,9 +335,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
301335 hidden_states
302336 )
303337 deepstack_feature_lists .append (deepstack_feature )
304- # ##breakpoint()
305338 hidden_states = self .merger (hidden_states )
306- # ##breakpoint()
307339 return hidden_states , deepstack_feature_lists
308340
309341
@@ -322,7 +354,7 @@ def forward(
322354 rotary_pos_emb : Optional [torch .Tensor ] = None ,
323355 position_embeddings : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
324356 ) -> torch .Tensor :
325- # ## breakpoint()
357+ # breakpoint()
326358 seq_length = hidden_states .shape [0 ]
327359 q , k , v = self .qkv (hidden_states ).reshape (seq_length , 3 , self .num_heads , - 1 ).permute (1 , 0 , 2 , 3 ).unbind (0 )
328360 if position_embeddings is None :
@@ -390,7 +422,7 @@ def eager_attention_forward(
390422 past_key_value : Optional [Cache ] = None ,
391423 ** kwargs ,
392424):
393- # ## breakpoint()
425+ # breakpoint()
394426 key_states = repeat_kv (key , module .num_key_value_groups )
395427 value_states = repeat_kv (value , module .num_key_value_groups )
396428
@@ -425,7 +457,7 @@ def forward(
425457 position_embeddings : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
426458 ** kwargs ,
427459 ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
428- ## breakpoint()
460+ # breakpoint()
429461 input_shape = hidden_states .shape [:- 1 ]
430462 hidden_shape = (* input_shape , - 1 , self .head_dim )
431463 bsz , q_len , _ = hidden_states .size ()
@@ -898,8 +930,9 @@ def get_specializations(
898930 prefill_seq_len : int ,
899931 ctx_len : int ,
900932 img_size : None ,
901- height : int = None ,
902- width : int = None ,
933+ # height: int = None,
934+ # width: int = None,
935+ dimensions : List = None ,
903936 num_frames : int = 1 ,
904937 kv_offload : bool = False ,
905938 continuous_batching : bool = False ,
@@ -910,12 +943,13 @@ def get_specializations(
910943 # ##breakpoint()
911944 comp_ctx_lengths_prefill = compiler_options .pop ("comp_ctx_lengths_prefill" , None )
912945 comp_ctx_lengths_decode = compiler_options .pop ("comp_ctx_lengths_decode" , None )
913- if height is None or width is None :
946+ if dimensions is None :
914947 height = 1365
915948 width = 2048
916949 logger .warning (
917950 "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
918951 )
952+ dimensions = [[height , width ]]
919953 prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
920954 ctx_len = ctx_len if ctx_len else constants .INTERN_CTX_LEN
921955 channel = 3
@@ -972,30 +1006,53 @@ def smart_resize(
9721006 w_bar = ceil_by_factor (width * beta , factor )
9731007 return h_bar , w_bar
9741008
975- resized_height , resized_width = smart_resize (height = height , width = width )
976- grid_h , grid_w = resized_height // patch_size , resized_width // patch_size
977- grid_height = grid_h * grid_w
978- grid_width = patch_size * patch_size * temporal_patch_size * channel
979- vision_size = grid_height // 4
980- vision_size = vision_size * num_frames
981- grid_height = grid_height * batch_size
1009+ # resized_height, resized_width = smart_resize(height=height, width=width)
1010+ # grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
1011+ # grid_height = grid_h * grid_w
1012+ # grid_width = patch_size * patch_size * temporal_patch_size * channel
1013+ # vision_size = grid_height // 4
1014+ # vision_size = vision_size * num_frames
1015+ # grid_height = grid_height * batch_size
1016+ # # breakpoint()
1017+ # # vision_size = 176
1018+ # # grid_height = 704
1019+ # # grid_width = 1536
1020+ # # grid_h = 22
1021+ # # grid_w = 32
1022+ # # breakpoint()
1023+ vision = []
1024+ max_vision_size = 0
9821025 # breakpoint()
983- # vision_size = 176
984- # grid_height = 704
985- # grid_width = 1536
986- # grid_h = 22
987- # grid_w = 32
988- # breakpoint()
989- vision = [
990- {
991- "batch_size" : batch_size ,
992- "vision_size" : vision_size ,
993- "grid_height" : grid_height ,
994- "grid_width" : grid_width ,
995- "grid_h" : grid_h ,
996- "grid_w" : grid_w ,
997- }
998- ]
1026+ for dimension in dimensions :
1027+ resized_height , resized_width = smart_resize (height = dimension [0 ], width = dimension [1 ])
1028+ grid_h , grid_w = resized_height // patch_size , resized_width // patch_size
1029+ grid_height = grid_h * grid_w
1030+ grid_width = patch_size * patch_size * temporal_patch_size * channel
1031+ vision_size = grid_height // 4
1032+ vision_size = vision_size * num_frames
1033+ grid_height = grid_height * batch_size
1034+
1035+ max_vision_size = max (max_vision_size , vision_size )
1036+ # vision = [
1037+ # {
1038+ # "batch_size": batch_size,
1039+ # "vision_size": max_vision_size,
1040+ # "grid_height": grid_height,
1041+ # "grid_width": grid_width,
1042+ # "grid_h": grid_h,
1043+ # "grid_w": grid_w,
1044+ # }
1045+ # ]
1046+ vision .append (
1047+ {
1048+ "batch_size" : batch_size ,
1049+ "vision_size" : vision_size ,
1050+ "grid_height" : grid_height ,
1051+ "grid_width" : grid_width ,
1052+ "grid_h" : grid_h ,
1053+ "grid_w" : grid_w ,
1054+ }
1055+ )
9991056 # ##breakpoint()
10001057
10011058 if comp_ctx_lengths_prefill is not None :
@@ -1006,7 +1063,7 @@ def smart_resize(
10061063 "batch_size" : 1 if continuous_batching else batch_size ,
10071064 "seq_len" : prefill_seq_len ,
10081065 "ctx_len" : ctx_len ,
1009- "vision_size" : vision_size ,
1066+ "vision_size" : max_vision_size ,
10101067 "comp_ctx_lengths" : comp_ctx_lengths_prefill [i ],
10111068 "vision_batch_size" : batch_size ,
10121069 }
@@ -1025,7 +1082,7 @@ def smart_resize(
10251082 "batch_size" : full_batch_size if continuous_batching else batch_size ,
10261083 "seq_len" : "1" ,
10271084 "ctx_len" : ctx_len ,
1028- "vision_size" : vision_size ,
1085+ "vision_size" : max_vision_size ,
10291086 "comp_ctx_lengths" : comp_ctx_lengths_decode [i ],
10301087 "vision_batch_size" : batch_size ,
10311088 }
@@ -1041,7 +1098,7 @@ def smart_resize(
10411098 "batch_size" : 1 if continuous_batching else batch_size ,
10421099 "seq_len" : prefill_seq_len ,
10431100 "ctx_len" : ctx_len ,
1044- "vision_size" : vision_size ,
1101+ "vision_size" : max_vision_size ,
10451102 "vision_batch_size" : batch_size ,
10461103 }
10471104
@@ -1056,7 +1113,7 @@ def smart_resize(
10561113 "batch_size" : full_batch_size if continuous_batching else batch_size ,
10571114 "seq_len" : 1 ,
10581115 "ctx_len" : ctx_len ,
1059- "vision_size" : vision_size ,
1116+ "vision_size" : max_vision_size ,
10601117 "vision_batch_size" : batch_size ,
10611118 }
10621119
0 commit comments