Skip to content

Commit 31e6d42

Browse files
committed
fix
1 parent 8cadd99 commit 31e6d42

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import unittest
44
from typing import Any, Dict, List, Tuple
55
import torch
6-
import transformers
6+
7+
try:
8+
import transformers.masking_utils as masking_utils
9+
except ImportError:
10+
masking_utils = None
711
from onnx_diagnostic.ext_test_case import (
812
ExtTestCase,
913
ignore_warnings,
@@ -322,6 +326,7 @@ def test_phi2_export_module(self):
322326

323327
@ignore_warnings(UserWarning)
324328
@requires_torch("2.9")
329+
@hide_stdout()
325330
def test_phi2_export_interpreter(self):
326331
data = get_untrained_model_with_inputs("microsoft/phi-2")
327332
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -342,10 +347,11 @@ def test_phi2_export_interpreter(self):
342347
)
343348

344349
with torch_export_patches(patch_transformers=True, verbose=1):
345-
self.assertEqual(
346-
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
347-
patch_transformers.patched_sdpa_mask_recent_torch,
348-
)
350+
if masking_utils is not None:
351+
self.assertEqual(
352+
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
353+
patch_transformers.patched_sdpa_mask_recent_torch,
354+
)
349355
ep = torch.export.export(
350356
model,
351357
(),

0 commit comments

Comments
 (0)