Skip to content

Commit a9c7439

Browse files
committed
ruff
1 parent 6d8670b commit a9c7439

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.7
55
+++++
66

7+
* :pr:`196`: implements a patch to rewrite a loop in modeling_qwen2_vl.VisionAttention
8+
79
0.7.6
810
+++++
911

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,11 +1381,9 @@ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
13811381
less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
13821382
less = less0.sum(axis=-1, keepdim=True) + 1
13831383
sq = less * less.T
1384-
less_min = less.min()
1385-
less_max = less.max()
13861384
look = (
13871385
torch.max(seq.min() == 0, less != less.max())
1388-
* torch.max(seq.max() == mask.shape[-1], less != less_min)
1386+
* torch.max(seq.max() == mask.shape[-1], less != less.min())
13891387
* less
13901388
)
13911389
filt = (sq != look**2).to(mask.dtype)

0 commit comments

Comments
 (0)