Skip to content

Commit 5a87076

Browse files
lgeigerIsotr0py
andauthored
[Model][QwenVL] Optimize Qwen2_5_VisionAttention q,k preparation (#28769)
Signed-off-by: Lukas Geiger <[email protected]> Co-authored-by: Isotr0py <[email protected]>
1 parent ac1daf3 commit 5a87076

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

vllm/model_executor/models/dots_ocr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
)
4040
from vllm.model_executor.models.module_mapping import MultiModelKeys
4141
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
42-
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
4342
from vllm.model_executor.models.qwen2_vl import (
43+
Qwen2VisionAttention,
4444
Qwen2VLDummyInputsBuilder,
4545
Qwen2VLMultiModalProcessor,
4646
Qwen2VLProcessingInfo,
@@ -328,7 +328,7 @@ def forward(
328328
# [S, C] -> [S, B=1, C]
329329
x = hidden_states.unsqueeze(1)
330330
x, _ = self.qkv(x)
331-
q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
331+
q, k, v = Qwen2VisionAttention.split_qkv(self, x)
332332
bs = q.shape[1]
333333
# [S,B,H,D] -> [B,S,H,D]
334334
q = q.permute(1, 0, 2, 3).contiguous()

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -359,23 +359,6 @@ def __init__(
359359
AttentionBackendEnum.ROCM_AITER_FA,
360360
}
361361

362-
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
363-
# [s, b, 3 * head * head_dim]
364-
seq_len, bs, _ = qkv.shape
365-
366-
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
367-
q, k, v = qkv.chunk(3, dim=2)
368-
369-
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
370-
new_shape = (
371-
seq_len,
372-
bs,
373-
self.num_attention_heads_per_partition,
374-
self.hidden_size_per_attention_head,
375-
)
376-
q, k, v = (x.view(*new_shape) for x in (q, k, v))
377-
return q, k, v
378-
379362
def forward(
380363
self,
381364
x: torch.Tensor,
@@ -386,17 +369,32 @@ def forward(
386369
) -> torch.Tensor:
387370
# [s, b, c] --> [s, b, head * 3 * head_dim]
388371
x, _ = self.qkv(x)
372+
seq_len, batch_size, _ = x.shape
389373

390-
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
391-
q, k, v = self.split_qkv(x)
392-
batch_size = q.shape[1]
374+
qkv = einops.rearrange(
375+
x,
376+
"s b (three head head_dim) -> b s three head head_dim",
377+
three=3,
378+
head=self.num_attention_heads_per_partition,
379+
)
393380

394-
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
395381
if rotary_pos_emb is not None:
396-
# [2 * b, s, heads, head_dim]
397-
qk_concat = torch.cat([q, k], dim=0)
398-
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
399-
q, k = torch.chunk(qk_rotated, 2, dim=0)
382+
qk, v = qkv[:, :, :2], qkv[:, :, 2]
383+
384+
qk_reshaped = einops.rearrange(
385+
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
386+
)
387+
qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
388+
qk_rotated = qk_rotated.view(
389+
2,
390+
batch_size,
391+
seq_len,
392+
self.num_attention_heads_per_partition,
393+
self.hidden_size_per_attention_head,
394+
)
395+
q, k = qk_rotated.unbind(dim=0)
396+
else:
397+
q, k, v = qkv.unbind(dim=2)
400398

401399
if self.is_flash_attn_backend:
402400
context_layer = vit_flash_attn_wrapper(

0 commit comments

Comments
 (0)