Commit 6649c9a

add patch for qwen
1 parent 67aeb3e · commit 6649c9a

File tree

1 file changed: +64 -0 lines changed


onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 64 additions & 0 deletions
@@ -1,4 +1,5 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
@@ -1388,3 +1389,66 @@ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
     )
     filt = (sq != look**2).to(mask.dtype)
     return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                " computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                         cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
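
Note: the commented-out loop kept in the patched forward describes the block-diagonal mask that rewrite_loop_for_square_mask is expected to reproduce without a Python loop over cu_seqlens (a data-dependent loop that is hard to trace through torch.export). The sketch below is a minimal, standalone illustration of that equivalence; square_mask_with_loop and square_mask_vectorized are hypothetical helpers written for this note, not part of the commit, and the vectorized variant only assumes the "two positions attend to each other iff they fall in the same cu_seqlens segment" semantics visible in the loop.

    import torch

    def square_mask_with_loop(mask: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Reference behaviour, copied from the commented-out loop in the patch:
        # zero out every square window delimited by consecutive cu_seqlens entries.
        for i in range(1, len(cu_seqlens)):
            mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return mask

    def square_mask_vectorized(mask: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Loop-free equivalent: give each position a segment id, then unmask the
        # pairs that share a segment. No Python loop over tensor values remains.
        positions = torch.arange(mask.shape[-1], device=mask.device)
        segment = torch.bucketize(positions, cu_seqlens[1:], right=True)
        same_segment = segment.unsqueeze(0) == segment.unsqueeze(1)
        return mask.masked_fill(same_segment, 0)

    cu = torch.tensor([0, 3, 5])
    base = torch.full([1, 5, 5], torch.finfo(torch.float32).min)
    assert torch.equal(square_mask_with_loop(base.clone(), cu), square_mask_vectorized(base.clone(), cu))

The _PATCHES_ / _PATCHED_CLASS_ attributes follow the convention used by the other patch classes in this file: the list names the methods to swap in, and the attribute points at the transformers class being patched. As a hypothetical sketch only, a helper consuming that convention could look like the following; the actual mechanism lives elsewhere in onnx_diagnostic.torch_export_patches and may differ.

    import contextlib

    @contextlib.contextmanager
    def apply_patch(patch_cls):
        # Swap the listed methods onto the target class, restore them on exit.
        target = patch_cls._PATCHED_CLASS_
        saved = {name: getattr(target, name) for name in patch_cls._PATCHES_}
        try:
            for name in patch_cls._PATCHES_:
                setattr(target, name, getattr(patch_cls, name))
            yield target
        finally:
            for name, original in saved.items():
                setattr(target, name, original)

    # e.g. with apply_patch(patched_VisionAttention): export the model while the
    # loop-free forward is in place (illustrative usage, not the library's API).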

0 commit comments
