diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b142e67e..07358585 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
-        transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.54.0', 'main']
+        transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.0', 'main']
         torch: ['2.7', 'main']
         exclude:
           - python: '3.10'
@@ -28,7 +28,7 @@ jobs:
           - python: '3.10'
             transformers: '4.53.3'
           - python: '3.10'
-            transformers: '4.54.0'
+            transformers: '4.55.0'
           - python: '3.11'
             torch: 'main'
           - python: '3.11'
@@ -36,7 +36,7 @@ jobs:
           - python: '3.11'
             transformers: 'main'
           - python: '3.11'
-            transformers: '4.54.0'
+            transformers: '4.55.0'
           - python: '3.13'
             torch: '2.7'
           - python: '3.13'
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 86828035..f3129ded 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
 0.7.7
 +++++
 
+* :pr:`196`: implements a patch to rewrite a loop in modeling_qwen2_vl.VisionAttention
+
 0.7.6
 +++++
 
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 9ea777c7..3043eed0 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -287,7 +287,7 @@ def test_falcon_mamba_dev(self):
         model(**inputs)
         model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-        if not has_transformers("4.55"):
+        if not has_transformers("4.56"):
             raise unittest.SkipTest("The model has control flow.")
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
             torch.export.export(
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index a17c52f0..360caf3f 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -30,9 +30,15 @@ def test_image_text_to_text_idefics(self):
         )
 
     @hide_stdout()
-    @requires_transformers("4.53")
+    @requires_transformers("4.56")
     @requires_torch("2.7.99")
     def test_image_text_to_text_gemma3(self):
+        """
+        If the model fails because of
+        ``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
+        make sure this PR was merged:
+        https://github.com/huggingface/transformers/pull/39962.
+        """
         # mid = "google/gemma-3-4b-it"
         mid = "tiny-random/gemma-3"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
diff --git a/_unittests/ut_torch_export_patches/test_patch_rewriting.py b/_unittests/ut_torch_export_patches/test_patch_rewriting.py
new file mode 100644
index 00000000..7c8a0f8e
--- /dev/null
+++ b/_unittests/ut_torch_export_patches/test_patch_rewriting.py
@@ -0,0 +1,38 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    rewrite_loop_for_square_mask,
+)
+
+
+class TestPatchRewriting(ExtTestCase):
+    def test_rewrite_loop_for_square_mask(self):
+        import torch
+
+        seq_length = 8
+        dtype = torch.float32
+        mask = torch.full([1, seq_length, seq_length], 1, dtype=dtype)
+
+        def apply_mask(mask, seq):
+            mask = mask.clone()
+            for i in range(1, len(seq)):
+                mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+            return mask
+
+        for seqi in [
+            [1, 5, 8],
+            [1, 5, 7],
+            [2, 3, 6],
+            [2, 3, 3, 6],
+            [0, 1, 4, 5],
+            [0, 0, 5, 6],
+        ]:
+            with self.subTest(seq=seqi):
+                seq = torch.tensor(seqi, dtype=torch.int64)
+                m1 = apply_mask(mask, seq)
+                m2 = rewrite_loop_for_square_mask(mask, seq)
+                self.assertEqualArray(m1, m2)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py
index a76451e1..b36e6036 100644
--- a/onnx_diagnostic/tasks/image_text_to_text.py
+++ b/onnx_diagnostic/tasks/image_text_to_text.py
@@ -334,7 +334,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", "image_token_index")
+        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -348,7 +348,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, "image_size", "num_channels")
+    check_hasattr(config.vision_config, ("num_channels", "in_chans"))
     kwargs = dict(
         batch_size=2,
         sequence_length=43,
@@ -410,18 +410,34 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             if config is None
             else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        width=224 if config is None else config.vision_config.image_size,
-        height=224 if config is None else config.vision_config.image_size,
-        num_channels=3 if config is None else config.vision_config.num_channels,
+        width=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        height=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        num_channels=(
+            3 if config is None else _pick(config.vision_config, "num_channels", "in_chans")
+        ),
         pad_token_id=(
             0
-            if config is None or not hasattr(config, "text_config")
+            if config is None
+            or not hasattr(config, "text_config")
+            or not hasattr(config.text_config, "pad_token_id")
             else config.text_config.pad_token_id
         ),
         image_token_index=(
             4
-            if config is None or not hasattr(config, "image_token_index")
-            else config.image_token_index
+            if config is None
+            or (
+                not hasattr(config, "image_token_index")
+                and not hasattr(config, "image_token_id")
+            )
+            else _pick(config, "image_token_index", "image_token_id")
         ),
     )
     return kwargs, get_inputs
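
Note on the fallback logic above: some configurations expose ``image_token_index`` while others expose ``image_token_id``, and some vision configurations use ``in_chans`` instead of ``num_channels``; the tuples passed to ``check_hasattr`` and the ``_pick`` calls accept either spelling. A minimal sketch of the behaviour assumed here (``pick_first_attr`` below is an illustrative stand-in, not the library's ``_pick``):

from types import SimpleNamespace


def pick_first_attr(obj, *names):
    # Illustrative only: return the value of the first attribute that exists on obj.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"none of {names} found on {obj!r}")


# One config style exposes image_token_index, another image_token_id.
cfg_a = SimpleNamespace(image_token_index=4)
cfg_b = SimpleNamespace(image_token_id=7)
print(pick_first_attr(cfg_a, "image_token_index", "image_token_id"))  # 4
print(pick_first_attr(cfg_b, "image_token_index", "image_token_id"))  # 7
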
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index af944ea1..63f51aad 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,4 +1,5 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
@@ -1363,3 +1364,91 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         else:
             outputs = outputs + (None,)  # noqa: RUF005
         return outputs
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                " computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
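
For clarity, a small sketch comparing the vectorized ``rewrite_loop_for_square_mask`` against the data-dependent loop it replaces in ``VisionAttention.forward`` (the import path is the one used by the new unit test above; the ``cu_seqlens`` value is an arbitrary illustrative example):

import torch
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    rewrite_loop_for_square_mask,
)

seq_length = 6
# cu_seqlens-style boundaries: two segments, [0, 2) and [2, 6).
cu_seqlens = torch.tensor([0, 2, 6], dtype=torch.int64)
mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)

# Original loop: data-dependent slicing that torch.export cannot trace
# with dynamic shapes.
expected = mask.clone()
for i in range(1, len(cu_seqlens)):
    expected[
        ..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = 0

# Vectorized rewrite: zeroes the same square blocks along the diagonal
# without a Python loop over cu_seqlens.
got = rewrite_loop_for_square_mask(mask, cu_seqlens)
assert bool((expected == got).all())
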