@@ -1,4 +1,5 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
@@ -1363,3 +1364,91 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
     else:
         outputs = outputs + (None,)  # noqa: RUF005
     return outputs
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                "computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). "
+                "In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
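
For reference (not part of the commit): rewrite_loop_for_square_mask labels every position with the index of the cu_seqlens segment it falls into (the `less` tensor), then zeroes exactly the mask entries whose row and column carry the same label, which reproduces the block-diagonal pattern the loop builds without any data-dependent Python control flow; presumably this is what keeps the module traceable for torch.export / ONNX export. The extra `look` term only matters when cu_seqlens does not start at 0 or end at the sequence length, in which case the uncovered positions stay masked. A minimal sanity check of the equivalence, assuming rewrite_loop_for_square_mask from the diff above is in scope and that cu_seqlens starts at 0 and ends at the sequence length as the Qwen2-VL vision tower produces (loop_square_mask is a hypothetical reference written only for this comparison):

    import torch


    def loop_square_mask(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
        # Reference implementation: the original loop from the docstring,
        # applied to a copy so the input mask is left untouched.
        mask = mask.clone()
        for i in range(1, len(seq)):
            mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
        return mask


    seq_length = 7
    cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int64)  # two blocks: 3 and 4 tokens
    base = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)

    expected = loop_square_mask(base, cu_seqlens)
    got = rewrite_loop_for_square_mask(base, cu_seqlens)  # helper from the diff above
    assert torch.equal(expected, got)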
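
patched_VisionAttention is presumably never instantiated directly: the _PATCHES_ / _PATCHED_CLASS_ attributes suggest a patching helper copies the listed methods onto transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention. That helper is not part of this diff; the sketch below is only a hand-rolled illustration of the idea under that assumption (the name apply_method_patch is invented here and is not the project's API):

    import contextlib


    @contextlib.contextmanager
    def apply_method_patch(patch_cls):
        # Temporarily replace the methods listed in _PATCHES_ on the target class
        # with the versions defined on the patch class, restoring them afterwards.
        target = patch_cls._PATCHED_CLASS_
        originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
        try:
            for name in patch_cls._PATCHES_:
                setattr(target, name, getattr(patch_cls, name))
            yield target
        finally:
            for name, original in originals.items():
                setattr(target, name, original)


    # Hypothetical usage:
    # with apply_method_patch(patched_VisionAttention):
    #     ...  # run torch.export / ONNX export on the Qwen2-VL vision model here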