Skip to content

Commit 4583e52

Browse files
committed
fix
1 parent 135366d commit 4583e52

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from onnx_diagnostic.ext_test_case import (
88
ExtTestCase,
99
requires_transformers,
10+
requires_torch,
1011
ignore_warnings,
1112
has_onnxscript,
1213
)
@@ -465,6 +466,7 @@ def forward(
465466
)
466467

467468
@requires_transformers("4.99")
469+
@requires_torch("2.9.99")
468470
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
469471
def test_qwen2_5_vl_vision_attention_iteration(self):
470472
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2435,8 +2435,8 @@ def _iteration(start_end, query_states, key_states, value_states):
24352435
cu_seqlens.dtype
24362436
)
24372437
dot = dot.sum(dim=0)
2438-
mask = dot.unsqueeze(1) @ dot.unsqueeze(0)
2439-
bool_mask = mask == dot**2
2438+
mask = dot.unsqueeze(1) - dot.unsqueeze(0)
2439+
bool_mask = mask == 0
24402440
bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
24412441

24422442
torch._check(bool_mask.shape[2] == key_states.shape[2])

0 commit comments

Comments
 (0)