Commit 62147bc

implements a patch to rewrite a loop (#196)
* implements a patch to rewrite a loop
* ruff
* switch to 4.55
* 4.56
* fix issues
* qwen...
* add patch for qwen
1 parent e4d73e6 commit 62147bc

7 files changed: +164, -13 lines changed
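Note on the change: VisionAttention in modeling_qwen2_vl builds its block-diagonal attention mask with a Python loop whose slice bounds come from the values of cu_seqlens, i.e. data-dependent indexing that torch.export generally cannot trace symbolically. This commit replaces that loop with tensor operations. A minimal sketch of the idea, simplified and not the exact implementation added below (the simplified version assumes the boundaries start at 0 and end at the sequence length, which the real rewrite_loop_for_square_mask does not require):

import torch


def loop_mask(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
    # Original pattern: zero out the square block [seq[i-1]:seq[i], seq[i-1]:seq[i]]
    # for every pair of consecutive boundaries. The slice bounds are tensor values,
    # which the exporter cannot resolve without concrete data.
    mask = mask.clone()
    for i in range(1, len(seq)):
        mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
    return mask


def vectorized_mask(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
    # Simplified vectorized equivalent: label each position with the segment it
    # falls into, then keep the mask only where the row and column labels differ.
    r = torch.arange(mask.shape[-1])
    segment = (r.reshape(-1, 1) >= seq.reshape(1, -1)).sum(dim=-1)
    return mask * (segment.reshape(-1, 1) != segment.reshape(1, -1)).to(mask.dtype)


mask = torch.full((1, 6, 6), torch.finfo(torch.float32).min)
seq = torch.tensor([0, 3, 6], dtype=torch.int64)
assert bool((loop_mask(mask, seq) == vectorized_mask(mask, seq)).all())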

.github/workflows/ci.yml
Lines changed: 3 additions & 3 deletions

@@ -16,7 +16,7 @@ jobs:
      matrix:
        os: [ubuntu-latest]
        python: ['3.10', '3.11', '3.12', '3.13']
-       transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.54.0', 'main']
+       transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.0', 'main']
        torch: ['2.7', 'main']
        exclude:
          - python: '3.10'
@@ -28,15 +28,15 @@ jobs:
          - python: '3.10'
            transformers: '4.53.3'
          - python: '3.10'
-           transformers: '4.54.0'
+           transformers: '4.55.0'
          - python: '3.11'
            torch: 'main'
          - python: '3.11'
            transformers: '4.53.3'
          - python: '3.11'
            transformers: 'main'
          - python: '3.11'
-           transformers: '4.54.0'
+           transformers: '4.55.0'
          - python: '3.13'
            torch: '2.7'
          - python: '3.13'

CHANGELOGS.rst
Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@ Change Logs
 0.7.7
 +++++

+* :pr:`196`: implements a patch to rewrite a loop in modeling_qwen2_vl.VisionAttention
+
 0.7.6
 +++++

_unittests/ut_tasks/test_tasks.py
Lines changed: 1 addition & 1 deletion

@@ -287,7 +287,7 @@ def test_falcon_mamba_dev(self):
        model(**inputs)
        model(**data["inputs2"])
        self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-       if not has_transformers("4.55"):
+       if not has_transformers("4.56"):
            raise unittest.SkipTest("The model has control flow.")
        with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
            torch.export.export(

_unittests/ut_tasks/test_tasks_image_text_to_text.py
Lines changed: 7 additions & 1 deletion

@@ -30,9 +30,15 @@ def test_image_text_to_text_idefics(self):
        )

    @hide_stdout()
-   @requires_transformers("4.53")
+   @requires_transformers("4.56")
    @requires_torch("2.7.99")
    def test_image_text_to_text_gemma3(self):
+       """
+       If the model fails because of
+       ``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
+       make sure this PR was merged:
+       https://github.com/huggingface/transformers/pull/39962.
+       """
        # mid = "google/gemma-3-4b-it"
        mid = "tiny-random/gemma-3"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)

New test file
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    rewrite_loop_for_square_mask,
+)
+
+
+class TestPatchRewriting(ExtTestCase):
+    def test_rewrite_loop_for_square_mask(self):
+        import torch
+
+        seq_length = 8
+        dtype = torch.float32
+        mask = torch.full([1, seq_length, seq_length], 1, dtype=dtype)
+
+        def apply_mask(mask, seq):
+            mask = mask.clone()
+            for i in range(1, len(seq)):
+                mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+            return mask
+
+        for seqi in [
+            [1, 5, 8],
+            [1, 5, 7],
+            [2, 3, 6],
+            [2, 3, 3, 6],
+            [0, 1, 4, 5],
+            [0, 0, 5, 6],
+        ]:
+            with self.subTest(seq=seqi):
+                seq = torch.tensor(seqi, dtype=torch.int64)
+                m1 = apply_mask(mask, seq)
+                m2 = rewrite_loop_for_square_mask(mask, seq)
+                self.assertEqualArray(m1, m2)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
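
The test above checks numerical equivalence; the reason for the rewrite is traceability. A hedged sketch, not part of this commit, of checking that the vectorized helper goes through torch.export (the wrapper module and tensor sizes are arbitrary choices for illustration):

import torch

from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    rewrite_loop_for_square_mask,
)


class MaskModule(torch.nn.Module):
    # Thin wrapper so torch.export has a module to trace.
    def forward(self, mask, seq):
        return rewrite_loop_for_square_mask(mask, seq)


mask = torch.full((1, 8, 8), torch.finfo(torch.float32).min)
seq = torch.tensor([0, 3, 8], dtype=torch.int64)
# The original loop would slice with tensor-valued bounds; the vectorized form
# only uses arange, comparisons and elementwise ops, so it should export cleanly.
ep = torch.export.export(MaskModule(), (mask, seq))
print(ep.graph)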

onnx_diagnostic/tasks/image_text_to_text.py
Lines changed: 24 additions & 8 deletions

@@ -334,7 +334,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
            "hidden_size",
            "pad_token_id",
        )
-       check_hasattr(config, "vision_config", "image_token_index")
+       check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
        text_config = True
    else:
        check_hasattr(
@@ -348,7 +348,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
            "vision_config",
        )
        text_config = False
-   check_hasattr(config.vision_config, "image_size", "num_channels")
+   check_hasattr(config.vision_config, ("num_channels", "in_chans"))
    kwargs = dict(
        batch_size=2,
        sequence_length=43,
@@ -410,18 +410,34 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
            if config is None
            else (config.text_config.hidden_size if text_config else config.hidden_size)
        ),
-       width=224 if config is None else config.vision_config.image_size,
-       height=224 if config is None else config.vision_config.image_size,
-       num_channels=3 if config is None else config.vision_config.num_channels,
+       width=(
+           224
+           if config is None or not hasattr(config.vision_config, "image_size")
+           else config.vision_config.image_size
+       ),
+       height=(
+           224
+           if config is None or not hasattr(config.vision_config, "image_size")
+           else config.vision_config.image_size
+       ),
+       num_channels=(
+           3 if config is None else _pick(config.vision_config, "num_channels", "in_chans")
+       ),
        pad_token_id=(
            0
-           if config is None or not hasattr(config, "text_config")
+           if config is None
+           or not hasattr(config, "text_config")
+           or not hasattr(config.text_config, "pad_token_id")
            else config.text_config.pad_token_id
        ),
        image_token_index=(
            4
-           if config is None or not hasattr(config, "image_token_index")
-           else config.image_token_index
+           if config is None
+           or (
+               not hasattr(config, "image_token_index")
+               and not hasattr(config, "image_token_id")
+           )
+           else _pick(config, "image_token_index", "image_token_id")
        ),
    )
    return kwargs, get_inputs
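
This file leans on two helpers whose definitions are not part of the commit: check_hasattr, which here accepts a tuple meaning "at least one of these attributes must exist", and _pick, which reads the first attribute that is present. A rough stand-in for that behavior, written only as an assumption about their semantics (the real implementations live elsewhere in onnx_diagnostic):

from typing import Any


def _pick_sketch(obj: Any, *names: str) -> Any:
    # Hypothetical stand-in for _pick: return the first attribute that exists.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"none of {names} found on {type(obj).__name__}")


def check_hasattr_sketch(obj: Any, *names) -> None:
    # Hypothetical stand-in for check_hasattr: a plain string must be present,
    # a tuple means at least one of the listed alternatives must be present.
    for name in names:
        if isinstance(name, tuple):
            assert any(hasattr(obj, n) for n in name), f"missing all of {name}"
        else:
            assert hasattr(obj, name), f"missing attribute {name!r}"

Read this way, the change presumably accommodates Qwen2-VL style configurations (the commit message mentions qwen), which expose image_token_id and in_chans rather than image_token_index and num_channels, so they flow through the same input-generation path.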

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
Lines changed: 89 additions & 0 deletions

@@ -1,4 +1,5 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
@@ -1363,3 +1364,91 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
    else:
        outputs = outputs + (None,)  # noqa: RUF005
    return outputs
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                " computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
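
The vectorized construction in rewrite_loop_for_square_mask is compact but not obvious. A worked illustration of the intermediate tensors for seq_length = 6 and cu_seqlens = [0, 3, 6], with values checked by hand (the extra look factor only matters when cu_seqlens does not span the whole sequence):

import torch

mask = torch.full((1, 6, 6), torch.finfo(torch.float32).min)
seq = torch.tensor([0, 3, 6], dtype=torch.int64)

r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
# `less` labels each position with 1 + the number of boundaries strictly above it,
# so positions belonging to the same segment share the same label.
less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
less = less0.sum(dim=-1, keepdim=True) + 1
print(less.flatten())                 # tensor([3, 3, 3, 2, 2, 2])

# sq[i, j] = less[i] * less[j]; it equals less[i] ** 2 exactly when row i and
# column j carry the same label, i.e. when they fall inside the same block.
sq = less * less.T
filt = (sq != less**2).to(mask.dtype)
print(filt[:3, :3])                   # all zeros: the [0:3, 0:3] block is unmasked
print(int((mask * filt == 0).sum()))  # 18 positions, i.e. two 3x3 blocks stay attendable

# When cu_seqlens covers [0, seq_length] as here, `look` in the real function
# reduces to `less`, so `mask * filt` matches the original loop exactly; the
# extra factors in `look` keep positions outside the covered range masked.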
