Skip to content

Commit af7dfb0

Browse files
authored
[Perf] Further optimization for Qwen3-VL fast_pos_embed_interpolate (#25347)
Signed-off-by: Isotr0py <[email protected]>
1 parent 1c3ffdb commit af7dfb0

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

vllm/model_executor/models/qwen3_vl.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -405,25 +405,39 @@ def fast_pos_embed_interpolate(self,
405405
dh = h_idxs - h_floor
406406
dw = w_idxs - w_floor
407407

408-
w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
409-
w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
410-
w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
411-
w11 = (dh[:, None] * dw[None, :]).reshape(-1)
412-
413-
idx00 = (h_floor[:, None] * num_grid_per_side +
414-
w_floor[None, :]).reshape(-1)
415-
idx01 = (h_floor[:, None] * num_grid_per_side +
416-
w_ceil[None, :]).reshape(-1)
417-
idx10 = (h_ceil[:, None] * num_grid_per_side +
418-
w_floor[None, :]).reshape(-1)
419-
idx11 = (h_ceil[:, None] * num_grid_per_side +
420-
w_ceil[None, :]).reshape(-1)
421-
422-
indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
408+
# Create meshgrid view for all h, w vars
409+
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
410+
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
411+
w_floor,
412+
indexing='ij')
413+
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
414+
w_ceil,
415+
indexing='ij')
416+
h_floor_grid_idx = h_floor_grid * num_grid_per_side
417+
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
418+
419+
# original computation of weights
420+
# w00 = (1 - dh_grid) * (1 - dw_grid)
421+
# w01 = (1 - dh_grid) * dw_grid
422+
# w10 = dh_grid * (1 - dw_grid)
423+
# w11 = dh_grid * dw_grid
424+
# we reuse w11 here to avoid duplicate
425+
# dh_grid * dw_grid computation
426+
w11 = dh_grid * dw_grid
427+
w10 = dh_grid - w11
428+
w01 = dw_grid - w11
429+
w00 = 1 - dh_grid - dw_grid + w11
430+
431+
idx00 = h_floor_grid_idx + w_floor_grid
432+
idx01 = h_floor_grid_idx + w_ceil_grid
433+
idx10 = h_ceil_grid_idx + w_floor_grid
434+
idx11 = h_ceil_grid_idx + w_ceil_grid
435+
436+
indices = torch.stack([idx00, idx01, idx10, idx11],
437+
dim=0).reshape(4, -1)
423438
weights = torch.stack([w00, w01, w10, w11],
424-
dim=0).to(dtype=self.dtype,
425-
device=self.device)
426-
weights = weights.unsqueeze(-1)
439+
dim=0).reshape(4, -1, 1)
440+
weights = weights.to(dtype=self.dtype, device=self.device)
427441

428442
embeds = self.pos_embed(indices)
429443
weighted_embeds = embeds * weights

0 commit comments

Comments
 (0)