Skip to content

Commit 23c0cff

Browse files
committed
fix
1 parent 261d692 commit 23c0cff

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,16 @@ jobs:
9494
- name: run tests bypassed
9595
run: PYTHONPATH=. python _unittests/ut_torch_models/test_tiny_llms_bypassed.py
9696

97+
- name: test image_classification
98+
run: PYTHONPATH=. python _unittests/ut_tasks/test_tasks_image_classification.py
99+
97100
- name: test zero_shot_image_classification
98101
run: PYTHONPATH=. python _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
99102

100103
- name: run tests
101104
run: |
102105
pip install pytest
103-
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
106+
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py
104107
105108
- name: run backend tests python
106109
run: PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class patched_AttentionMaskConverter:
5252
"""
5353

5454
# This method was fixed in 4.51 at least.
55-
_PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.50.9999") else []
55+
_PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
5656
_PATCHED_CLASS_ = AttentionMaskConverter
5757

5858
@staticmethod
@@ -71,6 +71,9 @@ def _make_causal_mask(
7171
This static method may be called with ``AttentionMaskConverter._make_causal_mask``
7272
or ``self._make_causal_mask``. That changes this argument is receives.
7373
That should not matter but...
74+
The patch should be implemented in another way. static methods do not play well
75+
with a simple replacement.
76+
Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
7477
"""
7578
if args:
7679
index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1

0 commit comments

Comments
 (0)