Skip to content

Commit fc34f80

Browse files
committed
patch
1 parent 3f53a64 commit fc34f80

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,14 @@ def torch_export_patches(
350350
patch_transformers_list, verbose=verbose
351351
)
352352

353-
if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
354-
if verbose:
355-
print(
356-
"[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv"
357-
)
358-
f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
359-
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
353+
if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
354+
if verbose:
355+
print(
356+
"[torch_export_patches] patches "
357+
"transformers.masking_utils._vmap_for_bhqkv"
358+
)
359+
f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
360+
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
360361

361362
if custom_patches:
362363
if verbose:
@@ -450,6 +451,10 @@ def torch_export_patches(
450451
##############
451452

452453
if patch_transformers:
454+
try:
455+
import transformers.masking_utils as masking_utils
456+
except ImportError:
457+
masking_utils = None
453458
if verbose:
454459
print("[torch_export_patches] unpatch transformers")
455460
unpatch_module_or_classes(

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
1313
"""Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
1414
from ...helpers import string_type
1515

16-
dimensions: List[Optional[int]] = [(None, None, None, 0), (None, None, 0, None)]
16+
dimensions: List[Tuple[Optional[int], ...]] = [
17+
(None, None, None, 0),
18+
(None, None, 0, None),
19+
]
1720
if bh_indices:
1821
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
1922
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]

0 commit comments

Comments
 (0)