Commit 907ee41

fix

1 parent 4bec737 commit 907ee41

2 files changed: +116 −86 lines

onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Lines changed: 20 additions & 4 deletions

@@ -420,7 +420,11 @@ def torch_export_patches(
             patch_transformers_list, verbose=verbose
         )
 
-        if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "_vmap_for_bhqkv")
+        ):
             if verbose:
                 print(
                     "[torch_export_patches] patches "
@@ -429,7 +433,11 @@ def torch_export_patches(
             f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
             masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
-        if masking_utils and hasattr(masking_utils, "eager_mask"):
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "eager_mask")
+        ):
             if verbose:
                 print(
                     "[torch_export_patches] patches "
@@ -548,15 +556,23 @@ def torch_export_patches(
             patch_transformers_list, revert_patches_info, verbose=verbose
         )
 
-        if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "_vmap_for_bhqkv")
+        ):
             masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
             if verbose:
                 print(
                     "[torch_export_patches] restored "
                     "transformers.masking_utils._vmap_for_bhqkv"
                 )
 
-        if masking_utils and hasattr(masking_utils, "eager_mask"):
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "eager_mask")
+        ):
             f_transformers_eager_mask = masking_utils.eager_mask
             masking_utils.eager_mask = f_transformers_eager_mask
             if (
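
Both the patch and the restore paths now gate the masking_utils monkey-patch on three conditions: the module object is set, the new patch_transformers_list.patch_masking_utils flag says transformers.masking_utils could be imported, and the attribute actually exists. A minimal sketch of that save/patch/restore pattern is below; apply_patch and restore_patch are illustrative helper names, not part of onnx_diagnostic.

import types
from typing import Callable, Optional


def apply_patch(
    target_module: Optional[types.ModuleType],
    patch_enabled: bool,
    attr: str,
    patched_fn: Callable,
) -> Optional[Callable]:
    # Replace target_module.attr and hand back the original so it can be restored.
    if target_module and patch_enabled and hasattr(target_module, attr):
        original = getattr(target_module, attr)
        setattr(target_module, attr, patched_fn)
        return original
    # Nothing to patch: module missing (older transformers) or patching disabled.
    return None


def restore_patch(
    target_module: Optional[types.ModuleType],
    patch_enabled: bool,
    attr: str,
    original: Optional[Callable],
) -> None:
    # Mirror the guard used when patching so patch and restore stay symmetric.
    if target_module and patch_enabled and hasattr(target_module, attr) and original:
        setattr(target_module, attr, original)

Using the same guard on both sides means that with an older transformers release that has no masking_utils, neither branch runs and the module is left untouched.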

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
Lines changed: 96 additions & 82 deletions

@@ -7,60 +7,107 @@
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
-from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+try:
+    import transformers.masking_utils
+
+    patch_masking_utils = True
+except ImportError:
+    patch_masking_utils = False
+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting
 
 
-def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
-    from ...helpers import string_type
-
-    dimensions: List[Tuple[Optional[int], ...]] = [
-        (None, None, None, 0),
-        (None, None, 0, None),
-    ]
-    if bh_indices:
-        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
-    # reshape
-    dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
-    dimensions = tuple(reversed(dimensions))
-    indices = tuple(shape.index(-1) for shape in dimensions)
-
-    # unsqueeze
-    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
-
-    def vector_mask_function(
-        *args, mask_function=mask_function, dimensions=dimensions, indices=indices
-    ):
-        assert len(args) == len(dimensions) == len(udimensions), (
-            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
-            f"and udimensions={udimensions}."
-        )
-        assert len(indices) == len(args), (
-            f"Mismatch between args={string_type(args)} and indices={indices}, "
-            f"they should have the same length."
+if patch_masking_utils:
+    # Introduced in 4.52
+    from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+    def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+        """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+        from ...helpers import string_type
+
+        dimensions: List[Tuple[Optional[int], ...]] = [
+            (None, None, None, 0),
+            (None, None, 0, None),
+        ]
+        if bh_indices:
+            dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+        # reshape
+        dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+        dimensions = tuple(reversed(dimensions))
+        indices = tuple(shape.index(-1) for shape in dimensions)
+
+        # unsqueeze
+        udimensions = [
+            tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
+        ]
+
+        def vector_mask_function(
+            *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+        ):
+            assert len(args) == len(dimensions) == len(udimensions), (
+                f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+                f"and udimensions={udimensions}."
+            )
+            assert len(indices) == len(args), (
+                f"Mismatch between args={string_type(args)} and indices={indices}, "
+                f"they should have the same length."
+            )
+            for a in args:
+                assert (
+                    a.ndim == 1
+                ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+                torch._check(a.shape[0] > 0)
+
+            new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+            # new_args = [
+            #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+            #     for a, dims in zip(args, udimensions)
+            # ]
+            max_shape = tuple(args[i].shape[0] for i in indices)
+            # if is_torchdynamo_exporting():
+            #     for a in args:
+            #         # The exporter should export with a dimension > 1
+            #         # to make sure it is dynamic.
+            #         torch._check(a.shape[0] > 1)
+            expanded_args = [a.expand(max_shape) for a in new_args]
+            return mask_function(*expanded_args)
+
+        return vector_mask_function
+
+    def patched_eager_mask(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> torch.Tensor:
+        """manual patch for function ``transformers.masking_utils.eager_mask``."""
+        # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
+        _ = kwargs.pop("allow_is_causal_skip", None)
+        mask = sdpa_mask(
+            batch_size=batch_size,
+            cache_position=cache_position,
+            kv_length=kv_length,
+            kv_offset=kv_offset,
+            mask_function=mask_function,
+            attention_mask=attention_mask,
+            allow_is_causal_skip=False,
+            allow_torch_fix=False,
+            **kwargs,
         )
-        for a in args:
-            assert (
-                a.ndim == 1
-            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
-            torch._check(a.shape[0] > 0)
-
-        new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
-        # new_args = [
-        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
-        #     for a, dims in zip(args, udimensions)
-        # ]
-        max_shape = tuple(args[i].shape[0] for i in indices)
-        # if is_torchdynamo_exporting():
-        #     for a in args:
-        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
-        #         torch._check(a.shape[0] > 1)
-        expanded_args = [a.expand(max_shape) for a in new_args]
-        return mask_function(*expanded_args)
-
-    return vector_mask_function
+        min_dtype = torch.finfo(dtype).min
+        # The patched line.
+        # we need 0s where the tokens should be taken into account,
+        # and -inf otherwise (mask is already of boolean type)
+        # mask =
+        # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+        mask = (~mask).to(dtype) * min_dtype
+        return mask
 
 
 def _patch_make_causal_mask(
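
The relocated patched__vmap_for_bhqkv replaces transformers' torch.vmap-based mask vectorization with plain reshape/expand broadcasting, which is presumably easier for torch.export to trace. The sketch below illustrates the same broadcasting idea on a simple causal mask; causal and broadcast_mask are illustrative names, not the patched functions themselves.

import torch


def causal(batch_idx, head_idx, q_idx, kv_idx):
    # Scalar-style mask function: a query may attend to keys at or before its position.
    return kv_idx <= q_idx


def broadcast_mask(batch, heads, q_len, kv_len):
    # Reshape each 1-D index tensor to a broadcastable 4-D shape, expand to the
    # full (batch, heads, q, kv) shape, then call the mask function once.
    b = torch.arange(batch).reshape(-1, 1, 1, 1)
    h = torch.arange(heads).reshape(1, -1, 1, 1)
    q = torch.arange(q_len).reshape(1, 1, -1, 1)
    kv = torch.arange(kv_len).reshape(1, 1, 1, -1)
    shape = (batch, heads, q_len, kv_len)
    return causal(b.expand(shape), h.expand(shape), q.expand(shape), kv.expand(shape))


mask = broadcast_mask(1, 2, 3, 3)
assert mask.shape == (1, 2, 3, 3)
assert bool(mask[0, 0, 2, 1]) and not bool(mask[0, 0, 1, 2])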
@@ -1047,36 +1094,3 @@ def forward(
         attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
-
-def patched_eager_mask(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
-    kv_offset: int = 0,
-    mask_function: Callable = causal_mask_function,
-    attention_mask: Optional[torch.Tensor] = None,
-    dtype: torch.dtype = torch.float32,
-    **kwargs,
-) -> torch.Tensor:
-    """manual patch for function ``transformers.masking_utils.eager_mask``."""
-    # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
-    _ = kwargs.pop("allow_is_causal_skip", None)
-    mask = sdpa_mask(
-        batch_size=batch_size,
-        cache_position=cache_position,
-        kv_length=kv_length,
-        kv_offset=kv_offset,
-        mask_function=mask_function,
-        attention_mask=attention_mask,
-        allow_is_causal_skip=False,
-        allow_torch_fix=False,
-        **kwargs,
-    )
-    min_dtype = torch.finfo(dtype).min
-    # The patched line.
-    # we need 0s where the tokens should be taken into account,
-    # and -inf otherwise (mask is already of boolean type)
-    # mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
-    mask = (~mask).to(dtype) * min_dtype
-    return mask
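
The removed definition above is the old module-level patched_eager_mask, which now lives inside the if patch_masking_utils: block; its key line builds the additive float mask as (~mask).to(dtype) * min_dtype instead of torch.where. The small check below, with invented example values, shows the two formulations produce the same tensor.

import torch

# Boolean mask: True where a query token may attend to a key token.
bool_mask = torch.tensor([[True, False], [True, True]])
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Formulation kept as a comment in the patch: select 0.0 or the most negative value.
where_mask = torch.where(
    bool_mask, torch.tensor(0.0, dtype=dtype), torch.tensor(min_dtype, dtype=dtype)
)

# Patched formulation: invert, cast, scale -> 0.0 where allowed, min_dtype elsewhere.
additive_mask = (~bool_mask).to(dtype) * min_dtype

assert torch.allclose(where_mask, additive_mask)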
