@@ -1,4 +1,5 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
@@ -1388,3 +1389,66 @@ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
     )
     filt = (sq != look**2).to(mask.dtype)
     return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                "computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
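
Note on the mask rewrite: `rewrite_loop_for_square_mask` builds the same block-diagonal attention mask as the commented-out loop in the patched `forward`, using tensor operations only. The sketch below is a minimal, hypothetical illustration of that equivalence, not the patch's implementation; `block_diagonal_mask` is an illustrative helper, and `cu_seqlens` is assumed to hold cumulative window boundaries starting at 0, as in the Qwen2-VL vision attention.

```python
import torch


def block_diagonal_mask(cu_seqlens: torch.Tensor, seq_length: int, dtype=torch.float32):
    """Hypothetical vectorized construction of the block-diagonal mask."""
    # Assign each position to its window: position p belongs to window i
    # when cu_seqlens[i] <= p < cu_seqlens[i + 1].
    pos = torch.arange(seq_length)
    window = torch.searchsorted(cu_seqlens[1:], pos, right=True)
    same_window = window.unsqueeze(0) == window.unsqueeze(1)
    # Pairs in different windows keep the large negative value, pairs in the
    # same window are set to 0 so they can attend to each other.
    mask = torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    mask[same_window] = 0.0
    return mask.unsqueeze(0)


# Reference: the loop the patch comments out, reproduced on a toy example.
cu_seqlens = torch.tensor([0, 3, 7, 10])
ref = torch.full([1, 10, 10], torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    ref[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
assert torch.equal(block_diagonal_mask(cu_seqlens, 10), ref)
```

The assertion checks the vectorized construction against the original loop; keeping the mask loop-free means it depends on `cu_seqlens` only through plain tensor ops.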