
Commit 2523b0d

fix issues
1 parent 2b92218 commit 2523b0d

File tree

6 files changed: +246, -27 lines changed

_doc/technical/plot_generate.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def simple_generate_with_cache(
 dtype = get_weight_type(model)
 print("-- model dtype:", dtype)
 export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
-exporter = "custom" if "custom" in sys.argv else "onnx-dynamo"
+exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom"
 model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
 if not os.path.exists(model_name):
     # This step is slow so let's skip it if it was already done.

_unittests/ut_export/test_api.py

Lines changed: 19 additions & 4 deletions
@@ -1,9 +1,10 @@
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
 from onnx_diagnostic.helpers import max_diff
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.helpers.rt_helper import make_feeds
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.export.api import to_onnx
@@ -46,6 +47,10 @@ def test_tiny_llm_to_onnx(self):
             "onnx-dynamo": self.get_dump_file("test_tiny_llm_to_onnx-dynamo.onnx"),
             "modelbuilder": self.get_dump_file("model.onnx"),
         }
+        if not has_transformers("4.55"):
+            # <4.55: torch._check(causal_mask.shape[3] != 33)
+            # torch._check(causal_mask.shape[3] == 33)
+            del filenames["onnx-dynamo"]
         del inputs["position_ids"]
         del ds["position_ids"]
         del b1["position_ids"]
@@ -72,14 +77,24 @@ def test_tiny_llm_to_onnx(self):
                 diff = max_diff(expected, got)
                 assert diff["abs"] <= 1e-5, f"diff={diff}"
 
-        b1["attention_mask"][:, :] = 1
-        expected = model(**torch_deepcopy(b1))
+        problem = dict(
+            input_ids=torch.tensor([[24320]], dtype=torch.int64),
+            attention_mask=torch.tensor([[1, 1, 1, 1]], dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [
+                    torch.rand((1, 1, 3, 96), dtype=torch.float32),
+                    torch.rand((1, 1, 3, 96), dtype=torch.float32),
+                ]
+            ),
+        )
+
+        expected = model(**torch_deepcopy(problem))
         for exporter, filename in filenames.items():
             with self.subTest(exporter=f"full-mask-{exporter}"):
                 sess = onnxruntime.InferenceSession(
                     filename, providers=["CPUExecutionProvider"]
                 )
-                feeds = make_feeds(sess, b1, use_numpy=True)
+                feeds = make_feeds(sess, problem, use_numpy=True)
                 got = sess.run(None, feeds)
                 diff = max_diff(expected, got)
                 assert diff["abs"] <= 1e-5, f"diff={diff}"

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 5 additions & 3 deletions
@@ -48,7 +48,8 @@ def simple_generate_with_cache(
             f"\ninput_ids.shape={input_ids.shape}"
             f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
             f"\n got=\n"
-            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}"
+            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+            f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}"
         )
 
         # Next calls: decode
@@ -87,7 +88,8 @@ def simple_generate_with_cache(
             f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
             f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}"
             f"\n got=\n"
-            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}"
+            f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+            f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}"
         )
         return input_ids
 
@@ -113,7 +115,7 @@ def test_onnx_generate(self):
             kwargs=inputs,
             dynamic_shapes=ds,
             filename=model_name,
-            exporter="modelbuilder",
+            exporter="custom",
         )
 
         print("-- test_onnx_generate: generate")
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+import unittest
+import torch
+import transformers
+import transformers.integrations.sdpa_attention as sdpa_attention
+import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+
+
+class TestPatchPatchTransformers(ExtTestCase):
+    @requires_transformers("4.55")
+    def test_sdpa_mask_recent_torch(self):
+        sdpa_mask_recent_torch = transformers.masking_utils.sdpa_mask_recent_torch
+        patched_sdpa_mask_recent_torch = patch_transformers.patched_sdpa_mask_recent_torch
+        kwargs = {
+            "batch_size": 1,
+            "cache_position": torch.tensor([3], dtype=torch.int64),
+            "kv_length": 4,
+            "kv_offset": 0,
+            "mask_function": transformers.masking_utils.causal_mask_function,
+            "attention_mask": torch.tensor([[True, True, True, True]]),
+            "local_size": None,
+            "allow_is_causal_skip": True,
+            "allow_is_bidirectional_skip": False,
+        }
+        expected = sdpa_mask_recent_torch(**kwargs)
+        got = patched_sdpa_mask_recent_torch(**kwargs)
+        self.assertEqual(expected, got)
+
+        kwargs = {
+            "batch_size": 1,
+            "cache_position": torch.tensor([3], dtype=torch.int64),
+            "kv_length": 4,
+            "kv_offset": 0,
+            "mask_function": transformers.masking_utils.causal_mask_function,
+            "attention_mask": torch.tensor([[True, True, True, True]]),
+            "local_size": None,
+            "allow_is_causal_skip": False,
+            "allow_is_bidirectional_skip": False,
+        }
+        expected = sdpa_mask_recent_torch(**kwargs)
+        got = patched_sdpa_mask_recent_torch(**kwargs)
+        self.assertEqualArray(expected, got)
+
+    @requires_transformers("4.55")
+    def test_sdpa_attention_forward_not_causal(self):
+        sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": None,
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": False,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": torch.tensor([[[[True, True, True, True]]]]),
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": False,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+    @requires_transformers("4.55")
+    def test_sdpa_attention_forward_causal(self):
+        sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": torch.tensor([[[[True, True, True, True]]]]),
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": True,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+        kwargs = {
+            "module": None,
+            "query": torch.rand((1, 2, 1, 96), dtype=torch.float32),
+            "key": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "value": torch.rand((1, 2, 4, 96), dtype=torch.float32),
+            "attention_mask": None,
+            "attention_dropout": 0,
+            "scaling": 0.10206207261596575,
+            "is_causal": True,
+        }
+        expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
+        self.assertEqualArray(expected, got)
+
+    def test_causal_mask_in_scaled_dot_product_attention(self):
+        # see https://docs.pytorch.org/docs/stable/generated/...
+        # ...torch.nn.functional.scaled_dot_product_attention.html
+
+        query = torch.rand((1, 2, 1, 96), dtype=torch.float32)
+        key = torch.rand((1, 2, 4, 96), dtype=torch.float32)
+        L, S = query.size(-2), key.size(-2)
+        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+        self.assertEqual(attn_bias.min().item(), 0)
+        attn_causal_bias = attn_bias.clone()
+
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        self.assertEqual(attn_causal_bias.min().item(), -float("inf"))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
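The last test follows the reference mask construction from the PyTorch documentation for scaled_dot_product_attention. A small hedged illustration, not part of the test file, of why this matters for the export patch further below: with is_causal=True the reference mask is a lower-triangular (L, S) matrix aligned to the top-left, so during a decoding step where L=1 and S=4 the single query may only attend to the first key, even though the remaining keys come from the cache.

import torch

# Reference causal mask from the PyTorch docs, evaluated for a decoding step (L=1, S=4).
L, S = 1, 4
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
print(temp_mask)  # tensor([[ True, False, False, False]]) -> only the first key is visible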

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 7 additions & 1 deletion
@@ -856,9 +856,15 @@ def torch_deepcopy(value: Any) -> Any:
         ), f"Unexpected type={type(value)}"
         return copy.deepcopy(value)
 
+    if hasattr(value, "__nocopy__"):
+        return value
+
     # We should have a code using serialization, deserialization assuming a model
     # cannot be exported without them.
-    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
+    raise NotImplementedError(
+        f"torch_deepcopy not implemented for type {type(value)}, "
+        f"add attribute '__nocopy__' to return it as is."
+    )
 
 
 def torch_tensor_size(value: Any) -> Any:
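The change adds an escape hatch to torch_deepcopy: an object exposing a __nocopy__ attribute is returned as is instead of triggering NotImplementedError. A hedged sketch of how a caller might opt in (the class is hypothetical; only the attribute check comes from the change above, and it assumes no earlier branch of torch_deepcopy claims the object):

from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

class OpaqueHandle:
    # Hypothetical wrapper around something that cannot (or should not) be copied.
    __nocopy__ = True

    def __init__(self, resource):
        self.resource = resource

handle = OpaqueHandle(resource=object())
assert torch_deepcopy(handle) is handle  # returned unchanged, no copy attempted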

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 88 additions & 18 deletions
@@ -39,19 +39,45 @@
 except ImportError:
     patch_DynamicLayer = False
 
-from ...ext_test_case import has_transformers
-from ...helpers.torch_helper import is_torchdynamo_exporting
 
-patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+def _has_transformers(version: str) -> bool:
+    return pv.Version(transformers.__version__) >= pv.Version(version)
+
+
+def _is_torchdynamo_exporting() -> bool:
+    """
+    Tells if :epkg:`torch` is exporting a model.
+    Relies on ``torch.compiler.is_exporting()``.
+    """
+    import torch
+
+    if not hasattr(torch.compiler, "is_exporting"):
+        # torch.compiler.is_exporting requires torch>=2.7
+        return False
+
+    try:
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo
+
+            return dynamo.is_exporting()  # type: ignore
+        except Exception:
+            return False
+
+
+patch_is_initialized = _has_transformers("4.56.99")
 
 
 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import (
+        _ignore_causal_mask_sdpa,
+        _ignore_bidirectional_mask_sdpa,
+        and_masks,
+        bidirectional_mask_function,
         causal_mask_function,
         padding_mask_function,
-        and_masks,
-        _ignore_causal_mask_sdpa,
         prepare_padding_mask,
     )
 
@@ -98,7 +124,7 @@ def vector_mask_function(
     # for a, dims in zip(args, udimensions)
     # ]
     max_shape = tuple(args[i].shape[0] for i in indices)
-    # if is_torchdynamo_exporting():
+    # if _is_torchdynamo_exporting():
     # for a in args:
     # # The exporter should export with a dimension > 1
     # # to make sure it is dynamic.
@@ -151,6 +177,7 @@ def patched_sdpa_mask_recent_torch(
     attention_mask: Optional[torch.Tensor] = None,
     local_size: Optional[int] = None,
     allow_is_causal_skip: bool = True,
+    allow_is_bidirectional_skip: bool = False,
     **kwargs,
 ) -> Optional[torch.Tensor]:
     """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
@@ -160,6 +187,25 @@
         padding_mask, q_length, kv_length, kv_offset, local_size
     ):
         return None
+    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
+        return None
+
+    if mask_function is bidirectional_mask_function:
+        if padding_mask is not None:
+            # used for slicing without data-dependent slicing
+            mask_indices = (
+                torch.arange(kv_length, device=cache_position.device) + kv_offset
+            )
+            return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
+        return torch.ones(
+            batch_size,
+            1,
+            q_length,
+            kv_length,
+            dtype=torch.bool,
+            device=cache_position.device,
+        )
+
     kv_arange = torch.arange(kv_length, device=cache_position.device)
     kv_arange += kv_offset
     if padding_mask is not None:
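The new branch handles bidirectional_mask_function without data-dependent slicing: the kept positions are gathered through a precomputed index tensor and broadcast to the usual (batch, 1, q_length, kv_length) layout. A minimal sketch of that indexing pattern (tensor values are made up; only the expression is taken from the patched function above):

import torch

padding_mask = torch.tensor([[True, True, True, True]])  # (batch=1, kv_length=4)
kv_length, kv_offset, q_length = 4, 0, 1
mask_indices = torch.arange(kv_length) + kv_offset        # static indices, export-friendly
mask = padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
print(mask.shape)  # torch.Size([1, 1, 1, 4])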
@@ -275,7 +321,7 @@ class patched_AttentionMaskConverter:
     """
 
     # This method was fixed in 4.51 at least.
-    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
+    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
     _PATCHED_CLASS_ = AttentionMaskConverter
 
     @staticmethod
@@ -507,7 +553,7 @@ def _cache_dependant_input_preparation(
         The current implementation does not rely on ``self`` and could be
         a class method. It is left as a standard method to be easily rewritten.
         """
-        if is_torchdynamo_exporting():
+        if _is_torchdynamo_exporting():
             return self._cache_dependant_input_preparation_exporting(
                 input_ids, inputs_embeds, cache_position
             )
@@ -1316,16 +1362,40 @@ def patched_sdpa_attention_forward(
         attention_mask is None or attention_mask.shape[3] == key.shape[2],
         "Attention mask shape incompatible with key shape.",
     )
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
-        attn_mask=attention_mask,
-        dropout_p=dropout,
-        scale=scaling,
-        is_causal=is_causal,
-        **sdpa_kwargs,
-    )
+    if is_causal:
+        attn_output = torch.cond(
+            query.shape[2] > 1,  # distinction between prefill and decoding steps
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=True,
+                **sdpa_kwargs,
+            ),
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=False,
+                **sdpa_kwargs,
+            ),
+            [query, key, value],
+        )
+    else:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=is_causal,
+            **sdpa_kwargs,
+        )
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
 
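For the causal case the patched forward now goes through torch.cond, so the exporter captures both the prefill branch (several query positions, is_causal=True) and the decoding branch (a single query position attending to the whole cache, is_causal=False), with the predicate query.shape[2] > 1 choosing between them at runtime. A standalone sketch of that pattern (shapes are made up; dropout, scaling and sdpa_kwargs are omitted, and torch.cond requires a recent PyTorch):

import torch
import torch.nn.functional as F

query = torch.rand(1, 2, 1, 8)  # decoding step: a single query position
key = torch.rand(1, 2, 4, 8)    # cache of length 4
value = torch.rand(1, 2, 4, 8)

attn = torch.cond(
    query.shape[2] > 1,  # prefill vs decoding
    lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True),
    lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=False),
    [query, key, value],
)
print(attn.shape)  # torch.Size([1, 2, 1, 8])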