Skip to content

Commit 69df2e4

Browse files
committed
fix import issues
1 parent cf7dc28 commit 69df2e4

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
convert_dynamic_axes_into_dynamic_shapes,
2020
)
2121
from onnx_diagnostic.torch_export_patches import torch_export_patches
22-
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
23-
patched__vmap_for_bhqkv,
24-
)
22+
23+
try:
24+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
25+
patched__vmap_for_bhqkv,
26+
)
27+
except ImportError:
28+
patched__vmap_for_bhqkv = None
2529

2630

2731
class TestCacheHelpers(ExtTestCase):
@@ -258,6 +262,7 @@ def test_unflatten_flatten_hybrid_cache(self):
258262
self.string_type(unflat, with_shape=True),
259263
)
260264

265+
@unittest.skipIf(patched__vmap_for_bhqkv is None, "transformers too old")
261266
def test_cache_update_padding_mask_function_vmap(self):
262267
def causal_mask_function(
263268
batch_idx: int, head_idx: int, q_idx: int, kv_idx: int

0 commit comments

Comments
 (0)