
Commit 905201f

committed: split patches into multiple files
1 parent 0063681 commit 905201f

16 files changed: 2728 additions, 2610 deletions
Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
from typing import Optional
import torch
import transformers
from .patch_helper import _has_transformers

patch_sdpa_is_causal = _has_transformers("4.99")


def common_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # PATCHED
        # The two following lines were added.
        if attention_mask is not None and attention_mask.ndim == 4:
            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def patched_sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """
    manual patch for function
    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
    """
    assert not kwargs.get("output_attentions", False), (
        "`sdpa` attention does not support `output_attentions=True`."
        " Please set your attention to `eager` if you want any of these features."
    )
    torch._check(
        query.shape[0] == key.shape[0] or query.shape[0] == 1,
        lambda: (
            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
            f"value: {value.shape}"
        ),
    )
    torch._check(
        key.shape[0] == value.shape[0] or key.shape[0] == 1,
        lambda: (
            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
            f"value: {value.shape}"
        ),
    )

    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
            key = transformers.integrations.sdpa_attention.repeat_kv(
                key, module.num_key_value_groups
            )
            value = transformers.integrations.sdpa_attention.repeat_kv(
                value, module.num_key_value_groups
            )
        else:
            sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    torch._check(
        attention_mask is None or attention_mask.shape[3] == key.shape[2],
        lambda: "Attention mask shape incompatible with key shape.",
    )

    if patch_sdpa_is_causal:
        # transformers>=4.55
        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

        # PATCHED: remove the test query.shape[2] > 1
        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
        # and we split the test to keep the minimum in torch.cond
        is_causal = attention_mask is None and is_causal

        if not is_causal:
            torch._check(query.shape[0] > 0)
            torch._check(query.shape[1] > 0)
            torch._check(query.shape[2] > 0)
            torch._check(query.shape[3] > 0)
            torch._check(key.shape[0] > 0)
            torch._check(key.shape[1] > 0)
            torch._check(key.shape[2] > 0)
            torch._check(key.shape[3] > 0)
            torch._check(value.shape[0] > 0)
            torch._check(value.shape[1] > 0)
            torch._check(value.shape[2] > 0)
            torch._check(value.shape[3] > 0)
        return (
            torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=dropout,
                scale=scaling,
                is_causal=is_causal,
                **sdpa_kwargs,
            )
            .transpose(1, 2)
            .contiguous(),
            None,
        )
    else:
        # transformers<4.55
        if is_causal is None and attention_mask is not None:
            is_causal = False
        if is_causal is not None:
            return (
                torch.nn.functional.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=dropout,
                    scale=scaling,
                    is_causal=is_causal,
                    **sdpa_kwargs,
                )
                .transpose(1, 2)
                .contiguous(),
                None,
            )

        # To avoid the following errors:
        # is_causal=query.shape[2] > 1
        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
        # is_causal=torch.tensor(query.shape[2] > 1)
        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
        attn_output = torch.cond(
            query.shape[2] > 1,  # distinction between prefill and decoding steps
            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                dropout_p=dropout,
                scale=scaling,
                is_causal=True,
                **sdpa_kwargs,
            ).contiguous(),
            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                dropout_p=dropout,
                scale=scaling,
                is_causal=False,
                **sdpa_kwargs,
            ).contiguous(),
            [query, key, value],
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, None


def patched_model_bart_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
    return common_eager_attention_forward(
        module,
        query,
        key,
        value,
        attention_mask=attention_mask,
        scaling=scaling,
        dropout=dropout,
        head_mask=head_mask,
        **kwargs,
    )


def patched_modeling_marian_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
    return common_eager_attention_forward(
        module,
        query,
        key,
        value,
        attention_mask=attention_mask,
        scaling=scaling,
        dropout=dropout,
        head_mask=head_mask,
        **kwargs,
    )
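Note (not part of the commit): a patched kernel such as `patched_sdpa_attention_forward` only takes effect once it is swapped in for the stock implementation before exporting a model. Below is a minimal, hypothetical sketch of one plausible way to do that by monkey-patching the module path named in the docstring above; the context-manager name `apply_sdpa_patch` is illustrative, and depending on the transformers version the swap may instead have to go through its attention-function registry.

# Hypothetical sketch: temporarily replace the stock SDPA attention forward
# with the patched one, then restore it afterwards.
from contextlib import contextmanager

import transformers.integrations.sdpa_attention as sdpa_mod


@contextmanager
def apply_sdpa_patch():
    original = sdpa_mod.sdpa_attention_forward
    sdpa_mod.sdpa_attention_forward = patched_sdpa_attention_forward
    try:
        yield
    finally:
        # Always restore the original implementation, even if export fails.
        sdpa_mod.sdpa_attention_forward = original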
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from typing import Optional
import inspect
import transformers

try:
    from transformers.cache_utils import parse_processor_args  # noqa: F401

    patch_parse_processor_args = True
except ImportError:
    patch_parse_processor_args = False


if patch_parse_processor_args:

    def _init_cache_inspect():
        res = {}
        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
            try:
                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
                res[processor_class.__init__] = params
            except Exception:
                res[processor_class.__init__] = None
        return res

    _cache_inspect = _init_cache_inspect()

    def patched_parse_processor_args(
        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
    ) -> tuple[dict, dict]:
        """[patch:transformers.cache_utils.parse_processor_args]"""
        # If not patched...
        # Fails with transformers>=4.54 because function ``parse_processor_args``
        # relies on inspect and the exporter is not very fond of that.
        # torch._dynamo.exc.Unsupported: id() with unsupported args
        # Explanation: Dynamo doesn't know how to trace id()
        # call with args
        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
        # objects from outside the compiled region.
        # Hint: It may be possible to write Dynamo tracing rules for this code.
        #
        # The patch caches the signature to avoid any call to inspect.
        if processor_class is None:
            return {}, kwargs
        params = _cache_inspect[processor_class.__init__]
        if params is None:
            return {}, kwargs
        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
        return processor_kwargs, remaining_kwargs
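Note (not part of the commit): this function is meant to stand in for `transformers.cache_utils.parse_processor_args` at export time, guarded by the `patch_parse_processor_args` flag defined above. A minimal, hypothetical sketch of how it could be installed follows; the helper name `install_cache_patch` is illustrative.

# Hypothetical sketch: swap in the signature-caching variant so Dynamo never
# has to trace inspect.signature() during export.
import transformers.cache_utils as cache_utils


def install_cache_patch():
    if not patch_parse_processor_args:
        # Older transformers without parse_processor_args: nothing to patch.
        return None
    original = cache_utils.parse_processor_args
    cache_utils.parse_processor_args = patched_parse_processor_args
    return original  # keep the original around so the patch can be undone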
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from .patch_helper import _has_transformers


def _patch_make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """Patched method."""
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [
                torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
                mask,
            ],
            dim=-1,
        )

    if sliding_window is not None:
        diagonal = past_key_values_length - sliding_window - 1

        context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
        # and used masked_fill instead of masked_fill_
        # In this case, the current implementation of torch fails (17/12/2024).
        # Try model Phi-3.5-Mini-Instruct.
        mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


@dataclass
class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
    """

    # This method was fixed in 4.51 at least.
    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
    _PATCHED_CLASS_ = AttentionMaskConverter

    @staticmethod
    def _make_causal_mask(
        *args,
        **kwargs,
        # input_ids_shape: torch.Size,
        # dtype: torch.dtype,
        # device: torch.device,
        # past_key_values_length: int = 0,
        # sliding_window: Optional[int] = None,
    ):
        """
        Patched method.

        This static method may be called as ``AttentionMaskConverter._make_causal_mask``
        or as ``self._make_causal_mask``, which changes the arguments it receives.
        That should not matter but...
        The patch should be implemented in another way: static methods do not play well
        with a simple replacement.
        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
        """
        if args:
            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
            names = [
                "input_ids_shape",
                "dtype",
                "device",
                "past_key_values_length",
                "sliding_window",
            ]
            for i, a in enumerate(args):
                if i < index:
                    continue
                kwargs[names[i - index]] = a
        return _patch_make_causal_mask(**kwargs)
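Note (not part of the commit): the `_PATCHES_` / `_PATCHED_CLASS_` attributes suggest the patch is applied by copying the listed methods onto the target class. A hypothetical sketch of such an applier follows; the name `apply_class_patch` is illustrative.

# Hypothetical sketch: copy every method listed in _PATCHES_ onto the class
# named by _PATCHED_CLASS_ and return the originals so they can be restored.
def apply_class_patch(patch_cls):
    target = patch_cls._PATCHED_CLASS_
    originals = {}
    for name in patch_cls._PATCHES_:
        originals[name] = getattr(target, name)
        # staticmethod() keeps the replacement callable without an implicit self.
        setattr(target, name, staticmethod(getattr(patch_cls, name)))
    return originals

For transformers>=4.48.3, `apply_class_patch(patched_AttentionMaskConverter)` is a no-op, since `_PATCHES_` is empty in that case.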
