Skip to content

Commit 747ddee

Browse files
committed
Fixing issue in fast_pos_embed
Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com>
1 parent 7e21c0f commit 747ddee

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -192,23 +192,20 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
192192
return embeddings
193193

194194
def fast_pos_embed_interpolate(self, grid_thw):
195-
# breakpoint()
196-
# gridbs, grid_ts, grid_hs, grid_ws = grid_thw.shape
197-
bs, t, h, w = grid_thw.shape
198-
# grid_ts = torch.tensor([grid_ts], device=grid_thw.device)
199-
# grid_hs = torch.tensor([grid_hs], device=grid_thw.device)
200-
# grid_ws = torch.tensor([grid_ws], device=grid_thw.device)
201-
idx_list = [[] for _ in range(4)]
202-
weight_list = [[] for _ in range(4)]
203-
# t,h,w = grid_ts[0],grid_hs[0],grid_ws[0]
204-
# for t, h, w in zip(grid_ts, grid_hs, grid_ws):
195+
bs,t,h,w=grid_thw.shape
196+
205197
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
206198
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
207199

208200
h_idxs_floor = h_idxs.int()
209201
w_idxs_floor = w_idxs.int()
210-
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
211-
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
202+
# h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
203+
# w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
204+
# To resolve the clip issue
205+
max_t = torch.tensor(self.num_grid_per_side - 1, device=h_idxs.device)
206+
207+
h_idxs_ceil = torch.minimum(h_idxs_floor + 1, max_t) # working
208+
w_idxs_ceil = torch.minimum(w_idxs_floor + 1, max_t)
212209

213210
dh = h_idxs - h_idxs_floor
214211
dw = w_idxs - w_idxs_floor
@@ -230,12 +227,12 @@ def fast_pos_embed_interpolate(self, grid_thw):
230227
(dh[None].T * dw[None]).flatten(),
231228
]
232229

233-
for i in range(4):
234-
idx_list[i].extend(indices[i].tolist())
235-
weight_list[i].extend(weights[i].tolist())
236-
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
237-
weight_tensor = torch.tensor(
238-
weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
230+
idx_tensor = torch.stack(indices, dim=0).to(
231+
dtype=torch.long, device=self.pos_embed.weight.device
232+
) # [4, h*w]
233+
234+
weight_tensor = torch.stack(weights, dim=0).to(
235+
dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
239236
)
240237
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
241238
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
@@ -244,9 +241,7 @@ def fast_pos_embed_interpolate(self, grid_thw):
244241

245242
patch_pos_embeds_permute = []
246243
merge_size = self.config.spatial_merge_size
247-
# breakpoint()
248244
pos_embed = patch_pos_embeds[0]
249-
# for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
250245
pos_embed = pos_embed.repeat(t, 1)
251246

252247
pos_embed = (

0 commit comments

Comments (0)