6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.10', '3.11', '3.12', '3.13']
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.54.0', 'main']
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.0', 'main']
torch: ['2.7', 'main']
exclude:
- python: '3.10'
@@ -28,15 +28,15 @@ jobs:
- python: '3.10'
transformers: '4.53.3'
- python: '3.10'
transformers: '4.54.0'
transformers: '4.55.0'
- python: '3.11'
torch: 'main'
- python: '3.11'
transformers: '4.53.3'
- python: '3.11'
transformers: 'main'
- python: '3.11'
transformers: '4.54.0'
transformers: '4.55.0'
- python: '3.13'
torch: '2.7'
- python: '3.13'
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
0.7.7
+++++

* :pr:`196`: implements a patch to rewrite a loop in modeling_qwen2_vl.VisionAttention

0.7.6
+++++

2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks.py
@@ -287,7 +287,7 @@ def test_falcon_mamba_dev(self):
model(**inputs)
model(**data["inputs2"])
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
if not has_transformers("4.55"):
if not has_transformers("4.56"):
raise unittest.SkipTest("The model has control flow.")
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
8 changes: 7 additions & 1 deletion _unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -30,9 +30,15 @@ def test_image_text_to_text_idefics(self):
)

@hide_stdout()
@requires_transformers("4.53")
@requires_transformers("4.56")
@requires_torch("2.7.99")
def test_image_text_to_text_gemma3(self):
"""
If the model fails because of
``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
make sure this PR was merged:
https://github.com/huggingface/transformers/pull/39962.
"""
# mid = "google/gemma-3-4b-it"
mid = "tiny-random/gemma-3"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
38 changes: 38 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_rewriting.py
@@ -0,0 +1,38 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
rewrite_loop_for_square_mask,
)


class TestPatchRewriting(ExtTestCase):
def test_rewrite_loop_for_square_mask(self):
import torch

seq_length = 8
dtype = torch.float32
mask = torch.full([1, seq_length, seq_length], 1, dtype=dtype)

def apply_mask(mask, seq):
mask = mask.clone()
for i in range(1, len(seq)):
mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
return mask

for seqi in [
[1, 5, 8],
[1, 5, 7],
[2, 3, 6],
[2, 3, 3, 6],
[0, 1, 4, 5],
[0, 0, 5, 6],
]:
with self.subTest(seq=seqi):
seq = torch.tensor(seqi, dtype=torch.int64)
m1 = apply_mask(mask, seq)
m2 = rewrite_loop_for_square_mask(mask, seq)
self.assertEqualArray(m1, m2)


if __name__ == "__main__":
unittest.main(verbosity=2)
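For intuition, a minimal standalone example (not part of the diff above) of the behaviour this test checks: with ``seq = [0, 2, 4]`` and a 4x4 mask of ones, ``rewrite_loop_for_square_mask`` zeroes the same two 2x2 diagonal blocks as the original loop and leaves the off-diagonal blocks untouched.

import torch
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    rewrite_loop_for_square_mask,
)

mask = torch.ones(1, 4, 4)
seq = torch.tensor([0, 2, 4], dtype=torch.int64)
# The loop zeroes mask[..., 0:2, 0:2] and mask[..., 2:4, 2:4]; nothing else changes.
expected = torch.tensor(
    [
        [
            [0.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
        ]
    ]
)
assert torch.equal(rewrite_loop_for_square_mask(mask, seq), expected)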
32 changes: 24 additions & 8 deletions onnx_diagnostic/tasks/image_text_to_text.py
@@ -334,7 +334,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"hidden_size",
"pad_token_id",
)
check_hasattr(config, "vision_config", "image_token_index")
check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
text_config = True
else:
check_hasattr(
Expand All @@ -348,7 +348,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"vision_config",
)
text_config = False
check_hasattr(config.vision_config, "image_size", "num_channels")
check_hasattr(config.vision_config, ("num_channels", "in_chans"))
kwargs = dict(
batch_size=2,
sequence_length=43,
@@ -410,18 +410,34 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
if config is None
else (config.text_config.hidden_size if text_config else config.hidden_size)
),
width=224 if config is None else config.vision_config.image_size,
height=224 if config is None else config.vision_config.image_size,
num_channels=3 if config is None else config.vision_config.num_channels,
width=(
224
if config is None or not hasattr(config.vision_config, "image_size")
else config.vision_config.image_size
),
height=(
224
if config is None or not hasattr(config.vision_config, "image_size")
else config.vision_config.image_size
),
num_channels=(
3 if config is None else _pick(config.vision_config, "num_channels", "in_chans")
),
pad_token_id=(
0
if config is None or not hasattr(config, "text_config")
if config is None
or not hasattr(config, "text_config")
or not hasattr(config.text_config, "pad_token_id")
else config.text_config.pad_token_id
),
image_token_index=(
4
if config is None or not hasattr(config, "image_token_index")
else config.image_token_index
if config is None
or (
not hasattr(config, "image_token_index")
and not hasattr(config, "image_token_id")
)
else _pick(config, "image_token_index", "image_token_id")
),
)
return kwargs, get_inputs
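The new branches above call a ``_pick`` helper that is defined elsewhere in ``onnx_diagnostic/tasks`` and does not appear in this hunk; the change widens the attribute lookup, presumably because some configurations expose ``image_token_id`` / ``in_chans`` rather than ``image_token_index`` / ``num_channels``. A rough, hypothetical stand-in for what such a helper does (the real implementation may differ):

def _pick_first_attribute(config, *names):
    # Hypothetical stand-in for the _pick helper used above: return the
    # value of the first attribute in `names` that `config` defines.
    for name in names:
        if hasattr(config, name):
            return getattr(config, name)
    raise AttributeError(f"none of {names} found on {type(config).__name__}")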
89 changes: 89 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,4 +1,5 @@
import inspect
import math
from dataclasses import dataclass
from functools import wraps
from typing import Callable, List, Optional, Tuple
@@ -1363,3 +1364,91 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
else:
outputs = outputs + (None,) # noqa: RUF005
return outputs


def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
"""
Rewrites the loop in:

.. code-block:: python

attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
)
for i in range(1, len(seq)):
attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
"""
r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
less = less0.sum(axis=-1, keepdim=True) + 1
sq = less * less.T
look = (
torch.max(seq.min() == 0, less != less.max())
* torch.max(seq.max() == mask.shape[-1], less != less.min())
* less
)
filt = (sq != look**2).to(mask.dtype)
return mask * filt


class patched_VisionAttention(torch.nn.Module):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = (
self.qkv(hidden_states)
.reshape(seq_length, 3, self.num_heads, -1)
.permute(1, 0, 2, 3)
.unbind(0)
)
if position_embeddings is None:
transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
"The attention layers in this model are transitioning from "
" computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), "
"to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin)."
" In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
q, k, cos, sin
)

attention_mask = torch.full(
[1, seq_length, seq_length],
torch.finfo(q.dtype).min,
device=q.device,
dtype=q.dtype,
)
# for i in range(1, len(cu_seqlens)):
# attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
# cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)

q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
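For reference (not part of the PR), the idea behind ``rewrite_loop_for_square_mask`` can be restated as follows: each position gets a segment id from the number of ``cu_seqlens`` boundaries at or below it, and an entry of the square mask is kept only when its row and column fall into different segments, or when one of them lies outside the range the loop covers. A sketch under that reading, assuming ``seq`` is non-decreasing as ``cu_seqlens`` is:

import torch


def square_mask_reference(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
    # Illustrative reference only: zero the same square blocks as the loop
    #     for i in range(1, len(seq)):
    #         mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
    n = mask.shape[-1]
    r = torch.arange(n, dtype=torch.int64)
    segment = (r.reshape(-1, 1) >= seq.reshape(1, -1)).sum(-1)  # segment id per position
    inside = (r >= seq.min()) & (r < seq.max())  # positions the loop actually touches
    same_block = (
        (segment.reshape(-1, 1) == segment.reshape(1, -1))
        & inside.reshape(-1, 1)
        & inside.reshape(1, -1)
    )
    return mask * (~same_block).to(mask.dtype)

Replacing the Python loop with tensor operations avoids data-dependent control flow over ``cu_seqlens``, which is what makes the patched forward traceable by ``torch.export``.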