Skip to content

Commit 25da3e9

Browse files
committed
fix make_fake
1 parent 3323490 commit 25da3e9

File tree

6 files changed

+329
-210
lines changed

6 files changed

+329
-210
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,40 @@ def test_make_fake_with_dynamic_dimensions_tensor(self):
183183
self.assertEqual(reshaped.shape[3], 96)
184184
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
185185

186+
@requires_transformers("4.55")
187+
@requires_torch("2.9")
188+
def test_make_fake_with_dynamic_dimensions_two_tensors(self):
189+
res = make_fake_with_dynamic_dimensions(
190+
(
191+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
192+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
193+
),
194+
({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
195+
)
196+
reshaped = res[0][0]
197+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
198+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
199+
self.assertEqual(reshaped.shape[1], 32)
200+
self.assertEqual(reshaped.shape[3], 96)
201+
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
202+
self.assertEqual(str(res[0][0].shape), str(res[0][1].shape))
203+
sh1 = res[0][0].shape
204+
sh2 = res[0][1].shape
205+
self.assertEqual(sh1[0], sh2[0])
206+
self.assertEqual(sh1[1], sh2[1])
207+
self.assertEqual(sh1[2], sh2[2])
208+
self.assertEqual(sh1[3], sh2[3])
209+
210+
def test_make_fake_with_dynamic_dimensions_attention(self):
211+
query = torch.rand((1, 2, 1, 96), dtype=torch.float32)
212+
key = torch.rand((1, 2, 4, 96), dtype=torch.float32)
213+
value = torch.rand((1, 2, 4, 96), dtype=torch.float32)
214+
ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
215+
fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
216+
self.assertEqual(fake_inputs[1].shape, fake_inputs[2].shape)
217+
self.assertEqual(fake_inputs[0].shape[0], fake_inputs[1].shape[0])
218+
self.assertEqual(fake_inputs[0].shape[0], fake_inputs[2].shape[0])
219+
186220
@requires_transformers("4.55")
187221
@requires_torch("2.9")
188222
def test_make_fake_with_dynamic_dimensions_whole(self):

_unittests/ut_helpers/test_fake_tensor_helper.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,72 +3,72 @@
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
44
from onnx_diagnostic.helpers import flatten_object
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
6-
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
6+
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, FakeTensorContext
77

88

99
class TestMakeTensorHelper(ExtTestCase):
1010

11+
@requires_transformers("4.55")
12+
def test_fake_inputs(self):
13+
inputs, _ = make_fake(
14+
dict(
15+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
16+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
17+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
18+
past_key_values=make_dynamic_cache(
19+
[
20+
(
21+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
22+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
23+
),
24+
(
25+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
26+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
27+
),
28+
]
29+
),
30+
)
31+
)
32+
flat = flatten_object(inputs, drop_keys=True)
33+
for t in flat:
34+
self.assertIsInstance(t, torch.Tensor)
35+
assert all(
36+
isinstance(s, torch.SymInt) for s in t.shape
37+
), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
38+
1139
def test_fake_reshape_generic(self):
1240
t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
13-
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
41+
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
1442
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
1543
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
1644
self.assertEqual(reshaped.shape[1], 3)
1745
self.assertEqual(reshaped.shape[3], 5)
1846

1947
def test_fake_reshape_dim_1(self):
2048
t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
21-
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
49+
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
2250
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
2351
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
2452
self.assertEqual(reshaped.shape[1], 3)
2553
self.assertEqual(reshaped.shape[3], 5)
2654

2755
def test_fake_reshape_dim_0(self):
2856
t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
29-
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
57+
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
3058
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
3159
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
3260
self.assertEqual(reshaped.shape[1], 3)
3361
self.assertEqual(reshaped.shape[3], 5)
3462

3563
def test_fake_reshape_different(self):
3664
t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
37-
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
65+
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
3866
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
3967
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
4068
self.assertEqual(reshaped.shape[1], 3)
4169
self.assertEqual(reshaped.shape[3], 5)
4270
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
4371

44-
@requires_transformers("4.55")
45-
def test_fake_inputs(self):
46-
inputs, _ = make_fake(
47-
dict(
48-
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
49-
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
50-
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
51-
past_key_values=make_dynamic_cache(
52-
[
53-
(
54-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
55-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
56-
),
57-
(
58-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
59-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
60-
),
61-
]
62-
),
63-
)
64-
)
65-
flat = flatten_object(inputs, drop_keys=True)
66-
for t in flat:
67-
self.assertIsInstance(t, torch.Tensor)
68-
assert all(
69-
isinstance(s, torch.SymInt) for s in t.shape
70-
), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
71-
7272

7373
if __name__ == "__main__":
7474
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,7 @@ def forward(self, query, key, value):
168168
# dynamic
169169
ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
170170
fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
171-
print("****", fake_inputs)
172-
epd = torch.export.export(model, fake_inputs) # , dynamic_shapes=use_dyn_not_str(ds))
173-
print(epq)
171+
epd = torch.export.export(model, fake_inputs, dynamic_shapes=use_dyn_not_str(ds))
174172
got = epd.module()(query, key, value)
175173
self.assertEqualArray(expected, got)
176174

onnx_diagnostic/export/shape_helper.py

Lines changed: 26 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, Dict, List, Set, Optional, Tuple, Union
22
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3-
from ..helpers.fake_tensor_helper import fake_reshape
43
from .dynamic_shapes import ModelInputs
54

65

@@ -203,14 +202,14 @@ def guess_dynamic_shapes_from_inputs(
203202

204203

205204
def make_fake_with_dynamic_dimensions(
206-
x: Any,
207-
dynamic_shapes: Any,
208-
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
209-
) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
205+
x: Any, dynamic_shapes: Any, context: Optional["FakeTensorContext"] = None # noqa: F821
206+
) -> Tuple[Any, "FakeTensorContext"]: # noqa: F821
210207
"""
211208
Replaces all tensors with fake tensors respecting the same
212209
constraints as the following dynamic shapes.
213210
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
211+
Parameter ``context`` is used to reuse the same object when the dynamic
212+
dimension is given the same name as another one.
214213
215214
A simple tensor:
216215
@@ -227,6 +226,24 @@ def make_fake_with_dynamic_dimensions(
227226
)
228227
print(inputs)
229228
229+
Two tensors:
230+
231+
.. runpython::
232+
:showcode:
233+
234+
import torch
235+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
236+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
237+
238+
inputs, _ = make_fake_with_dynamic_dimensions(
239+
(
240+
torch.rand((2, 3, 4, 5), dtype=torch.float32),
241+
torch.rand((2, 3, 4, 5), dtype=torch.float32),
242+
),
243+
({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
244+
)
245+
print(inputs)
246+
230247
With a cache:
231248
232249
.. runpython::
@@ -271,68 +288,9 @@ def make_fake_with_dynamic_dimensions(
271288
"""
272289
if x is None:
273290
return None, None
274-
if fake_mode is None:
275-
from torch.fx.experimental.symbolic_shapes import ShapeEnv
276-
from torch._subclasses.fake_tensor import FakeTensorMode
291+
if context is None:
292+
from ..helpers.fake_tensor_helper import FakeTensorContext
277293

278-
shape_env = ShapeEnv()
279-
fake_mode = FakeTensorMode(shape_env=shape_env)
294+
context = FakeTensorContext()
280295

281-
if isinstance(x, (list, tuple)):
282-
return (
283-
x.__class__(
284-
[
285-
make_fake_with_dynamic_dimensions(
286-
i, fake_mode=fake_mode, dynamic_shapes=ds
287-
)[0]
288-
for i, ds in zip(x, dynamic_shapes)
289-
]
290-
),
291-
fake_mode,
292-
)
293-
if isinstance(x, dict):
294-
return {
295-
k: make_fake_with_dynamic_dimensions(
296-
v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
297-
)[0]
298-
for k, v in x.items()
299-
}, fake_mode
300-
301-
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
302-
assert hasattr(x, "layers"), (
303-
f"Une more recent version of transformers (>=4.55), "
304-
f"'layers' not found in class {type(x)}"
305-
)
306-
assert isinstance(dynamic_shapes, list) and (
307-
not dynamic_shapes or not isinstance(dynamic_shapes[0], list)
308-
), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
309-
for il, layer in enumerate(x.layers):
310-
assert hasattr(layer, "keys") and hasattr(layer, "values"), (
311-
f"Une more recent version of transformers (>=4.55), 'layers' "
312-
f"not found in class {type(layer)} ({dir(layer)})"
313-
)
314-
layer.keys = make_fake_with_dynamic_dimensions(
315-
layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[il * 2]
316-
)[0]
317-
layer.values = make_fake_with_dynamic_dimensions(
318-
layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[il * 2 + 1]
319-
)[0]
320-
return x, fake_mode
321-
if x.__class__.__name__ == "EncoderDecoderCache":
322-
make_fake_with_dynamic_dimensions(
323-
x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
324-
)
325-
make_fake_with_dynamic_dimensions(
326-
x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
327-
)
328-
return x, fake_mode
329-
if hasattr(x, "shape"):
330-
t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
331-
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
332-
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
333-
return t, fake_mode
334-
from ..helpers import string_type
335-
336-
raise TypeError(
337-
f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
338-
)
296+
return context.make_fake_with_dynamic_dimensions(x, dynamic_shapes), context

0 commit comments

Comments
 (0)