
Commit 62742ed

Commit message: change
Parent: 75456d7

3 files changed, +12 −4 lines

_unittests/ut_tasks/test_tasks.py

Lines changed: 5 additions & 1 deletion
@@ -45,6 +45,8 @@ def test_image_classification(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
+        if not has_transformers("4.51"):
+            raise unittest.SkipTest("_patch_make_causal_mask patch fails when an issue arises")
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
@@ -131,7 +133,7 @@ def test_image_text_to_text(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
-        if not has_torch("2.8"):
+        if not has_torch("2.7.99"):
             raise unittest.SkipTest("sym_max does not work with dynamic dimension")
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
@@ -219,6 +221,8 @@ def test_zero_shot_image_classification(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
+        if not has_transformers("4.51"):
+            raise unittest.SkipTest("_patch_make_causal_mask patch fails when an issue arises")
         with bypass_export_some_errors(patch_transformers=True, verbose=10):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
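
The new skips rely on the version helpers `has_transformers` and `has_torch` from `onnx_diagnostic.ext_test_case`. A minimal sketch of how they presumably behave, inferred from usage here rather than copied from the library:

```python
# Assumed behavior of the version guards used in the tests above; the real
# implementations live in onnx_diagnostic.ext_test_case.
from packaging.version import Version


def has_transformers(min_version: str) -> bool:
    """True if the installed transformers is at least ``min_version``."""
    import transformers

    return Version(transformers.__version__) >= Version(min_version)


def has_torch(min_version: str) -> bool:
    """True if the installed torch is at least ``min_version``."""
    import torch

    return Version(torch.__version__) >= Version(min_version)
```

Under that reading, relaxing the bound from `"2.8"` to `"2.7.99"` lets torch 2.8 pre-release builds pass the guard: a version such as `2.8.0.dev20250401` compares below `2.8.0` but above `2.7.99`.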

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 4 additions & 2 deletions
@@ -44,7 +44,7 @@ def _catch_produce_guards_and_solve_constraints(
             raise
         if verbose:
             print(
-                f"[_catch_produce_guards_and_solve_constraints] ERROR"
+                f"[_catch_produce_guards_and_solve_constraints] ERROR: "
                 f"produce_guards_and_solve_constraints failed, "
                 f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                 f"fake_mode={fake_mode}\n"
@@ -54,6 +54,7 @@ def _catch_produce_guards_and_solve_constraints(
                 f"_is_torch_jit_trace={_is_torch_jit_trace}\n"
                 f"exc={e}\ngm={gm}"
             )
+        torch._dynamo.reset()
 
 
 def patch__check_input_constraints_for_graph(
@@ -70,13 +71,14 @@ def patch__check_input_constraints_for_graph(
             raise
         if verbose:
             print(
-                f"[_check_input_constraints_for_graph] ERROR"
+                f"[_check_input_constraints_for_graph] ERROR: "
                 f"_check_input_constraints_for_graph failed, "
                 f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                 f"input_placeholders={input_placeholders}\n"
                 f"range_constraints={range_constraints}\n"
                 f"exc={e}"
             )
+        torch._dynamo.reset()
 
 
 def patched_infer_size(a, b):
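
Both patched wrappers now call `torch._dynamo.reset()` after swallowing a constraint failure; the call clears dynamo's compiled-graph and guard caches so a later export does not reuse state from the failed attempt. A minimal sketch of that error path, reconstructed from the diff context (the helper name is hypothetical):

```python
import os

import torch


def _skip_or_reset(exc: Exception, verbose: int = 0) -> None:
    # Hypothetical distillation of the patched error handling above.
    if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
        raise exc  # skipping disabled, surface the failure
    if verbose:
        print(f"[_skip_or_reset] ERROR: constraint check failed\nexc={exc}")
    # Drop cached graphs and guards so the failed attempt cannot
    # leak into the next torch.export.export call.
    torch._dynamo.reset()
```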

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
+from ...ext_test_case import has_transformers
 from ...helpers.torch_test_helper import is_torchdynamo_exporting
 
 
@@ -50,7 +51,8 @@ class patched_AttentionMaskConverter:
     ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
     """
 
-    _PATCHES_ = ["_make_causal_mask"]
+    # This method was fixed in 4.51 at least.
+    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.50.9999") else []
     _PATCHED_CLASS_ = AttentionMaskConverter
 
     @staticmethod
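
On transformers at or above 4.51 (the `4.50.9999` bound presumably also admits 4.51 pre-releases, the same trick as `2.7.99` above), `_PATCHES_` is empty and the `_make_causal_mask` override is never installed. A minimal sketch of how a patch registry might consume these class attributes; `apply_patches` is hypothetical, not the library's actual entry point:

```python
def apply_patches(patch_class: type) -> None:
    """Install each method listed in _PATCHES_ onto _PATCHED_CLASS_ (illustrative)."""
    target = patch_class._PATCHED_CLASS_
    for name in patch_class._PATCHES_:  # empty list: nothing to patch
        setattr(target, name, getattr(patch_class, name))
```

With an empty `_PATCHES_`, such a call becomes a no-op, which matches the comment that the upstream method was fixed in 4.51.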
