|
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
-from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+try:
+    import transformers.masking_utils
+
+    patch_masking_utils = True
+except ImportError:
+    patch_masking_utils = False
+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting
 
 
-def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
-    from ...helpers import string_type
-
-    dimensions: List[Tuple[Optional[int], ...]] = [
-        (None, None, None, 0),
-        (None, None, 0, None),
-    ]
-    if bh_indices:
-        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
-    # reshape
-    dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
-    dimensions = tuple(reversed(dimensions))
-    indices = tuple(shape.index(-1) for shape in dimensions)
-
-    # unsqueeze
-    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
-
-    def vector_mask_function(
-        *args, mask_function=mask_function, dimensions=dimensions, indices=indices
-    ):
-        assert len(args) == len(dimensions) == len(udimensions), (
-            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
-            f"and udimensions={udimensions}."
-        )
-        assert len(indices) == len(args), (
-            f"Mismatch between args={string_type(args)} and indices={indices}, "
-            f"they should have the same length."
+if patch_masking_utils:
+    # Introduced in 4.52
+    from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+    def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+        """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+        from ...helpers import string_type
+
+        dimensions: List[Tuple[Optional[int], ...]] = [
+            (None, None, None, 0),
+            (None, None, 0, None),
+        ]
+        if bh_indices:
+            dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+        # reshape
+        dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+        dimensions = tuple(reversed(dimensions))
+        indices = tuple(shape.index(-1) for shape in dimensions)
+
+        # unsqueeze
+        udimensions = [
+            tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
+        ]
+
+        def vector_mask_function(
+            *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+        ):
+            assert len(args) == len(dimensions) == len(udimensions), (
+                f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+                f"and udimensions={udimensions}."
+            )
+            assert len(indices) == len(args), (
+                f"Mismatch between args={string_type(args)} and indices={indices}, "
+                f"they should have the same length."
+            )
+            for a in args:
+                assert (
+                    a.ndim == 1
+                ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+                torch._check(a.shape[0] > 0)
+
+            new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+            # new_args = [
+            #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+            #     for a, dims in zip(args, udimensions)
+            # ]
+            max_shape = tuple(args[i].shape[0] for i in indices)
+            # if is_torchdynamo_exporting():
+            #     for a in args:
+            #         # The exporter should export with a dimension > 1
+            #         # to make sure it is dynamic.
+            #         torch._check(a.shape[0] > 1)
+            expanded_args = [a.expand(max_shape) for a in new_args]
+            return mask_function(*expanded_args)
+
+        return vector_mask_function
+
+    def patched_eager_mask(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> torch.Tensor:
+        """manual patch for function ``transformers.masking_utils.eager_mask``."""
+        # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
+        _ = kwargs.pop("allow_is_causal_skip", None)
+        mask = sdpa_mask(
+            batch_size=batch_size,
+            cache_position=cache_position,
+            kv_length=kv_length,
+            kv_offset=kv_offset,
+            mask_function=mask_function,
+            attention_mask=attention_mask,
+            allow_is_causal_skip=False,
+            allow_torch_fix=False,
+            **kwargs,
         )
-        for a in args:
-            assert (
-                a.ndim == 1
-            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
-            torch._check(a.shape[0] > 0)
-
-        new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
-        # new_args = [
-        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
-        #     for a, dims in zip(args, udimensions)
-        # ]
-        max_shape = tuple(args[i].shape[0] for i in indices)
-        # if is_torchdynamo_exporting():
-        #     for a in args:
-        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
-        #         torch._check(a.shape[0] > 1)
-        expanded_args = [a.expand(max_shape) for a in new_args]
-        return mask_function(*expanded_args)
-
-    return vector_mask_function
+        min_dtype = torch.finfo(dtype).min
+        # The patched line.
+        # we need 0s where the tokens should be taken into account,
+        # and -inf otherwise (mask is already of boolean type)
+        # mask =
+        # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+        mask = (~mask).to(dtype) * min_dtype
+        return mask
 
 
 def _patch_make_causal_mask(
@@ -1047,36 +1094,3 @@ def forward( |
         attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
-
-def patched_eager_mask(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
-    kv_offset: int = 0,
-    mask_function: Callable = causal_mask_function,
-    attention_mask: Optional[torch.Tensor] = None,
-    dtype: torch.dtype = torch.float32,
-    **kwargs,
-) -> torch.Tensor:
-    """manual patch for function ``transformers.masking_utils.eager_mask``."""
-    # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
-    _ = kwargs.pop("allow_is_causal_skip", None)
-    mask = sdpa_mask(
-        batch_size=batch_size,
-        cache_position=cache_position,
-        kv_length=kv_length,
-        kv_offset=kv_offset,
-        mask_function=mask_function,
-        attention_mask=attention_mask,
-        allow_is_causal_skip=False,
-        allow_torch_fix=False,
-        **kwargs,
-    )
-    min_dtype = torch.finfo(dtype).min
-    # The patched line.
-    # we need 0s where the tokens should be taken into account,
-    # and -inf otherwise (mask is already of boolean type)
-    # mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
-    mask = (~mask).to(dtype) * min_dtype
-    return mask
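
For readers of the diff above, a minimal self-contained sketch (not part of the patch) of the idea behind patched__vmap_for_bhqkv: instead of wrapping the mask function in nested torch.vmap calls, each 1-D index tensor is reshaped onto its own axis and expanded to a common shape, so evaluating the elementwise mask function on the expanded grids yields the full (batch, head, q_idx, kv_idx) mask, presumably in a form that is easier for torch.export to trace.

import torch


def causal(batch_idx, head_idx, q_idx, kv_idx):
    # a query token attends to itself and to earlier key/value positions
    return kv_idx <= q_idx


batch, heads, q_len, kv_len = 2, 4, 5, 5
# one axis per index, mirroring the reshape step of the patched function
shapes = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
indices = [torch.arange(n) for n in (batch, heads, q_len, kv_len)]
grids = [a.reshape(s).expand(batch, heads, q_len, kv_len) for a, s in zip(indices, shapes)]
mask = causal(*grids)  # boolean tensor of shape (batch, heads, q_len, kv_len)
assert torch.equal(mask[0, 0], torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool)))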
|
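Similarly, the single patched line in patched_eager_mask replaces the torch.where construction with arithmetic on the boolean mask. A quick toy check (again an illustration, not code from the patch) that both formulations put 0 where a token is attended to and the dtype's most negative value elsewhere:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
bool_mask = torch.tensor([[True, False], [True, True]])  # True = position is kept

with_where = torch.where(bool_mask, torch.tensor(0.0, dtype=dtype), min_dtype)
patched = (~bool_mask).to(dtype) * min_dtype
assert torch.equal(with_where, patched)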