
Commit 07cadab

lgeiger and ywang96 authored
[Model][Qwen3VL] Cache positional embedding indices (#28475)
Signed-off-by: Lukas Geiger <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 637f292 commit 07cadab

1 file changed: +34 additions, −23 deletions

vllm/model_executor/models/qwen3_vl.py

Lines changed: 34 additions & 23 deletions
@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any
 
@@ -416,30 +416,41 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
+
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
     def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
         max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
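
The new rot_pos_ids helper is a cache-friendly rewrite of the index construction that rot_pos_emb previously redid on every call: the (h, w) row/column indices, regrouped by spatial-merge window, are now produced by an lru_cache keyed on (h, w, spatial_merge_size), so repeated image or video grids of the same shape reuse a single precomputed tensor. Below is a minimal standalone sketch of the same pattern, assuming a module-level function and made-up grid values rather than the actual vLLM class; it is illustrative only.

# Minimal sketch of the caching pattern used by this commit (standalone,
# not the vLLM class itself). Grid sizes below are hypothetical.
from functools import lru_cache

import numpy as np
import torch


@lru_cache(maxsize=1024)
def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
    """Return (h * w, 2) row/column indices grouped by spatial-merge windows."""
    h_div = h // spatial_merge_size
    w_div = w // spatial_merge_size

    # Row index of every patch, regrouped so the patches inside one
    # spatial_merge_size x spatial_merge_size window become contiguous.
    hpos = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
    hpos = hpos.reshape(h_div, spatial_merge_size, w_div, spatial_merge_size)
    hpos = hpos.transpose(0, 2, 1, 3).flatten()

    # Same regrouping for the column index.
    wpos = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
    wpos = wpos.reshape(h_div, spatial_merge_size, w_div, spatial_merge_size)
    wpos = wpos.transpose(0, 2, 1, 3).flatten()

    return torch.from_numpy(np.stack([hpos, wpos], axis=-1))


if __name__ == "__main__":
    # Hypothetical (t, h, w) grid entries: both share the same 4 x 6 patch grid,
    # so the second lookup is served from the cache; t > 1 repeats the rows.
    grid_thw = [(1, 4, 6), (2, 4, 6)]
    merge = 2
    pos_ids = torch.cat(
        [
            rot_pos_ids(h, w, merge)
            if t == 1
            else rot_pos_ids(h, w, merge).repeat(t, 1)
            for t, h, w in grid_thw
        ],
        dim=0,
    )
    print(pos_ids.shape)             # torch.Size([72, 2]) == ((1 + 2) * 4 * 6, 2)
    print(rot_pos_ids.cache_info())  # CacheInfo(hits=1, misses=1, ...)

Calling the helper twice with the same grid shape reports one miss and one hit in cache_info(), which is the saving the commit targets when many images or frames share a resolution.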
