Skip to content

Commit 72a1c89

Browse files
david6666666Junhong
authored andcommitted
[Performance][MM] Building the inverse permutation in O(n) time in Qwen2_5_VisionTransformer (vllm-project#24443)
Signed-off-by: Junhong <[email protected]> Co-authored-by: Junhong <[email protected]>
1 parent ba81b6d commit 72a1c89

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,15 @@ def compute_attn_mask_seqlen(
717717
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
718718
return max_seqlen, seqlens
719719

720+
@staticmethod
721+
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
722+
# building the inverse permutation in O(n) time
723+
inv = torch.empty_like(perm)
724+
inv[perm] = torch.arange(perm.numel(),
725+
device=perm.device,
726+
dtype=perm.dtype)
727+
return inv
728+
720729
def forward(
721730
self,
722731
x: torch.Tensor,
@@ -760,6 +769,8 @@ def forward(
760769

761770
rotary_pos_emb = torch.cat(rotary_pos_emb)
762771
window_index = torch.cat(window_index)
772+
# compute reverse indices
773+
reverse_indices = self.invert_permutation(window_index)
763774
cu_window_seqlens = torch.cat(cu_window_seqlens)
764775
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
765776
cu_seqlens = torch.cat(cu_seqlens)
@@ -813,7 +824,6 @@ def forward(
813824

814825
# adapter
815826
hidden_states = self.merger(hidden_states)
816-
reverse_indices = torch.argsort(window_index)
817827
hidden_states = hidden_states[reverse_indices, :]
818828
return hidden_states
819829

0 commit comments

Comments
 (0)