Commit ab5832b

Improves patch for attention, fixes make_fake_with_dynamic_shapes (#280)

* improve patch for attention
* fix make_fake
* rewriting
* unit
* mypy
* mypy
* fix documentation
* fix documentation
* fix patch for other versions of transformers
* fix patch
* fix
* rename
* doc
* spell
* improve
* split files

1 parent c61e539 commit ab5832b

File tree

14 files changed: 653 additions & 304 deletions

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.0
 +++++

+* :pr:`280`: fixes patches for sdpa_attention_forward for different versions of transformers
 * :pr:`278`: implements ``onnx_generate_with_genai``
 * :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
 * :pr:`276`: implements ``onnx_generate`` which implements method generate for an onnx model,

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 10 additions & 3 deletions
@@ -22,6 +22,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
 from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
+from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

 bsize, nheads, slen, dim = 2, 1, 30, 96

@@ -41,7 +42,11 @@
 # %%
 # Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
 # produces the corresponding dynamic shapes assuming they are all dynamic.
-ds = all_dynamic_shapes_from_inputs(inputs)
+# ``register_additional_serialization_functions(patch_transformers=True)`` registers
+# functions letting pytorch know how to serialize and deserialize the class DynamicCache.
+
+with register_additional_serialization_functions(patch_transformers=True):
+    ds = all_dynamic_shapes_from_inputs(inputs)
 pprint.pprint(ds)

 # %%
@@ -88,7 +93,8 @@ def flatten_unflatten_like_dynamic_shapes(obj):
         return list(subtrees)
     raise ValueError(
         f"Unable to interpret spec type {spec.type} "
-        f"(type is {type(spec.type)}, context is {spec.context})."
+        f"(type is {type(spec.type)}, context is {spec.context}), "
+        f"obj type is {type(obj)}."
     )


@@ -109,7 +115,8 @@ def fix_dynamic_shapes(inputs, dynamic_shapes):
     return _align(flat_unflat_inputs, dynamic_shapes)


-fixed_ds = fix_dynamic_shapes(inputs, ds2)
+with register_additional_serialization_functions(patch_transformers=True):
+    fixed_ds = fix_dynamic_shapes(inputs, ds2)
 pprint.pprint(fixed_ds)

 # %%
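
Note: the context manager is also useful on its own whenever an input structure contains a ``DynamicCache``. A minimal standalone sketch built from the same imports as the patched script (the shapes are arbitrary):

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

# two layers of (key, value) pairs, same layout as in the example above
cache = make_dynamic_cache(
    [(torch.rand(2, 1, 30, 96), torch.rand(2, 1, 30, 96)) for _ in range(2)]
)
inputs = dict(past_key_values=cache)

# without the registration, pytorch does not know how to flatten DynamicCache,
# so the helper cannot walk the structure to produce dynamic shapes
with register_additional_serialization_functions(patch_transformers=True):
    ds = all_dynamic_shapes_from_inputs(inputs)
pprint.pprint(ds)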

_doc/recipes/plot_export_dim1.py

Lines changed: 32 additions & 4 deletions
@@ -13,6 +13,8 @@

 import torch
 from onnx_diagnostic import doc
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import torch_export_patches


 class Model(torch.nn.Module):
@@ -29,21 +31,28 @@ def forward(self, x, y, z):
 DYN = torch.export.Dim.DYNAMIC
 ds = {0: DYN, 1: DYN}

+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-print(ep.graph)
+print(ep)

 # %%
 # Same model, a dynamic dimension = 1
 # +++++++++++++++++++++++++++++++++++

+
 z = z[:1]

 DYN = torch.export.Dim.DYNAMIC
 ds = {0: DYN, 1: DYN}

+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 try:
     ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-    print(ep.graph)
+    print(ep)
 except Exception as e:
     print("ERROR", e)

@@ -54,14 +63,33 @@ def forward(self, x, y, z):
 # Same model, a dynamic dimension = 1 and backed_size_oblivious=True
 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
 try:
     with torch.fx.experimental._config.patch(backed_size_oblivious=True):
         ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
-        print(ep.graph)
+        print(ep)
 except RuntimeError as e:
     print("ERROR", e)

+
+# %%
+# Final try with patches...
+# +++++++++++++++++++++++++
+
+print("-- export shape:", string_type((x, y, z), with_shape=True))
+print("-- dynamic shapes:", string_type((ds, ds, ds)))
+
+with torch_export_patches(patch_torch=1):
+    try:
+        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
+        print(ep)
+    except RuntimeError as e:
+        print("ERROR", e)
+
 # %%
-# It worked.
+# It is difficult to find the right option. It is possible on a simple model
+# but sometimes impossible on a bigger model mixing different shapes.

 doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")
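
Note: a quick way to see what the export decided for each dimension is to inspect the range constraints stored on the exported program; a dimension specialized to 0 or 1 simply no longer appears as a symbol. A minimal sketch, assuming ``ep`` holds the result of one of the successful exports above:

# each remaining symbol maps to the interval of sizes the guards accept
for symbol, value_range in ep.range_constraints.items():
    print(symbol, value_range)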

_unittests/ut_export/test_shape_helper.py

Lines changed: 34 additions & 0 deletions
@@ -183,6 +183,40 @@ def test_make_fake_with_dynamic_dimensions_tensor(self):
         self.assertEqual(reshaped.shape[3], 96)
         self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])

+    @requires_transformers("4.55")
+    @requires_torch("2.9")
+    def test_make_fake_with_dynamic_dimensions_two_tensors(self):
+        res = make_fake_with_dynamic_dimensions(
+            (
+                torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                torch.rand((2, 32, 30, 96), dtype=torch.float16),
+            ),
+            ({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
+        )
+        reshaped = res[0][0]
+        self.assertIsInstance(reshaped.shape[0], torch.SymInt)
+        self.assertIsInstance(reshaped.shape[2], torch.SymInt)
+        self.assertEqual(reshaped.shape[1], 32)
+        self.assertEqual(reshaped.shape[3], 96)
+        self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
+        self.assertEqual(str(res[0][0].shape), str(res[0][1].shape))
+        sh1 = res[0][0].shape
+        sh2 = res[0][1].shape
+        self.assertEqual(sh1[0], sh2[0])
+        self.assertEqual(sh1[1], sh2[1])
+        self.assertEqual(sh1[2], sh2[2])
+        self.assertEqual(sh1[3], sh2[3])
+
+    def test_make_fake_with_dynamic_dimensions_attention(self):
+        query = torch.rand((1, 2, 1, 96), dtype=torch.float32)
+        key = torch.rand((1, 2, 4, 96), dtype=torch.float32)
+        value = torch.rand((1, 2, 4, 96), dtype=torch.float32)
+        ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
+        fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
+        self.assertEqual(fake_inputs[1].shape, fake_inputs[2].shape)
+        self.assertEqual(fake_inputs[0].shape[0], fake_inputs[1].shape[0])
+        self.assertEqual(fake_inputs[0].shape[0], fake_inputs[2].shape[0])
+
     @requires_transformers("4.55")
     @requires_torch("2.9")
     def test_make_fake_with_dynamic_dimensions_whole(self):
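
Note: read from these tests, ``make_fake_with_dynamic_dimensions`` takes a structure of real tensors plus matching dynamic-shape specifications and returns fake tensors whose named dimensions map to shared ``torch.SymInt`` values. A minimal sketch of that contract (the import path is assumed from this test file's subject, ``onnx_diagnostic.export.shape_helper``):

import torch
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

query = torch.rand((1, 2, 1, 96), dtype=torch.float32)
key = torch.rand((1, 2, 4, 96), dtype=torch.float32)
ds = ({0: "batch", 2: "seq_q"}, {0: "batch", 2: "seq_kv"})

# dimensions sharing a name ("batch") come back as the same SymInt,
# distinct names ("seq_q", "seq_kv") as distinct symbols
fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key), ds)
assert fake_inputs[0].shape[0] == fake_inputs[1].shape[0]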

_unittests/ut_helpers/test_fake_tensor_helper.py

Lines changed: 33 additions & 33 deletions
@@ -3,72 +3,72 @@
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
 from onnx_diagnostic.helpers import flatten_object
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
-from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
+from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, FakeTensorContext


 class TestMakeTensorHelper(ExtTestCase):

+    @requires_transformers("4.55")
+    def test_fake_inputs(self):
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        flat = flatten_object(inputs, drop_keys=True)
+        for t in flat:
+            self.assertIsInstance(t, torch.Tensor)
+            assert all(
+                isinstance(s, torch.SymInt) for s in t.shape
+            ), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
+
     def test_fake_reshape_generic(self):
         t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_dim_1(self):
         t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_dim_0(self):
         t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)

     def test_fake_reshape_different(self):
         t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
-        reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
+        reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
         self.assertIsInstance(reshaped.shape[0], torch.SymInt)
         self.assertIsInstance(reshaped.shape[2], torch.SymInt)
         self.assertEqual(reshaped.shape[1], 3)
         self.assertEqual(reshaped.shape[3], 5)
         self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])

-    @requires_transformers("4.55")
-    def test_fake_inputs(self):
-        inputs, _ = make_fake(
-            dict(
-                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
-                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
-                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
-                past_key_values=make_dynamic_cache(
-                    [
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                        (
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
-                        ),
-                    ]
-                ),
-            )
-        )
-        flat = flatten_object(inputs, drop_keys=True)
-        for t in flat:
-            self.assertIsInstance(t, torch.Tensor)
-            assert all(
-                isinstance(s, torch.SymInt) for s in t.shape
-            ), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
-

 if __name__ == "__main__":
     unittest.main(verbosity=2)
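
Note: the commit replaces the free function ``fake_reshape`` with a method on a new ``FakeTensorContext`` class. Judging by the tests above, the migration is mechanical; presumably the context object owns the fake tensor mode so that tensors reshaped through the same context can share symbolic dimensions. A minimal sketch of the new call:

import torch
from onnx_diagnostic.helpers.fake_tensor_helper import FakeTensorContext

ctx = FakeTensorContext()
t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
# before this commit: reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
reshaped = ctx.fake_reshape(t, {0: "batch", 2: "seq_length"})
print(type(reshaped.shape[0]))  # torch.SymInt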

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 11 additions & 3 deletions
@@ -48,7 +48,7 @@ def test_text_generation_phi_3_mini_128k_instruct(self):

     @hide_stdout()
     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.8.99")  # check_guards not supported
     def test_text_generation_tiny_llm(self):
         mid = "arnir0/Tiny-LLM"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
@@ -58,12 +58,20 @@ def test_text_generation_tiny_llm(self):
         expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
-        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
+        with torch_export_patches(patch_transformers=True, verbose=1, patch_torch=False):
             ep = torch.export.export(
                 model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
             # print(ep)
-            got = ep.module()(**inputs_copied)
+            rem = []
+            for node in ep.graph.nodes:
+                if "_assert" in str(node.target):
+                    rem.append(node)
+            for node in rem:
+                ep.graph.erase_node(node)
+            ep.graph.lint()
+            mod = ep.module(check_guards=False)
+            got = mod(**inputs_copied)
         self.assertEqualAny(expected.past_key_values, got.past_key_values)
         self.assertEqualArray(expected.logits, got.logits)
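
Note: the new test body demonstrates a reusable pattern: when the runtime assertions inserted by ``torch.export`` reject inputs the model itself handles fine, the ``_assert`` nodes can be erased from the FX graph before the exported program is turned back into a module. A standalone sketch of the same idea (``check_guards=False`` is assumed to require a recent torch, hence the bumped ``requires_torch`` above):

import torch

def strip_runtime_assertions(ep: "torch.export.ExportedProgram") -> torch.nn.Module:
    # assertion nodes are side-effect only, so they have no users to rewire
    to_remove = [n for n in ep.graph.nodes if "_assert" in str(n.target)]
    for node in to_remove:
        ep.graph.erase_node(node)
    ep.graph.lint()  # check the graph is still well formed
    # check_guards=False also skips the shape guards when rebuilding the module
    return ep.module(check_guards=False)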
