Skip to content

Commit 8313a1b

Browse files
committed
fix shape
1 parent 2920aa5 commit 8313a1b

File tree

3 files changed

+138
-18
lines changed

3 files changed

+138
-18
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
44
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
55
from onnx_diagnostic.torch_export_patches import torch_export_patches
6+
from onnx_diagnostic.helpers import flatten_object
67
from onnx_diagnostic.helpers.cache_helper import (
78
make_dynamic_cache,
89
make_sliding_window_cache,
@@ -12,11 +13,11 @@
1213
from onnx_diagnostic.export.shape_helper import (
1314
all_dynamic_shapes_from_inputs,
1415
guess_dynamic_shapes_from_inputs,
16+
make_fake_with_dynamic_dimensions,
1517
)
1618

1719

1820
class TestShapeHelper(ExtTestCase):
19-
2021
@requires_transformers("4.52")
2122
@requires_torch("2.7.99")
2223
def test_all_dynamic_shape_from_cache(self):
@@ -184,6 +185,60 @@ def test_guess_dynamic_shapes_from_inputs(self):
184185
guessed,
185186
)
186187

188+
@requires_transformers("4.55")
189+
@requires_torch("2.9")
190+
def test_make_fake_with_dynamic_dimensions_tensor(self):
191+
res = make_fake_with_dynamic_dimensions(
192+
(torch.rand((2, 32, 30, 96), dtype=torch.float16),),
193+
({0: "batch", 2: "cache_length"},),
194+
)
195+
reshaped = res[0][0]
196+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
197+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
198+
self.assertEqual(reshaped.shape[1], 32)
199+
self.assertEqual(reshaped.shape[3], 96)
200+
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
201+
202+
@requires_transformers("4.55")
203+
@requires_torch("2.9")
204+
def test_make_fake_with_dynamic_dimensions_whole(self):
205+
res = make_fake_with_dynamic_dimensions(
206+
dict(
207+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
208+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
209+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
210+
past_key_values=make_dynamic_cache(
211+
[
212+
(
213+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
214+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
215+
),
216+
(
217+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
218+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
219+
),
220+
]
221+
),
222+
),
223+
dynamic_shapes={
224+
"input_ids": {0: "batch", 1: "seq_length"},
225+
"attention_mask": {0: "batch", 1: "cache+seq"},
226+
"position_ids": {0: "batch", 1: "seq_length"},
227+
"past_key_values": [
228+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
229+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
230+
],
231+
},
232+
)
233+
flat = flatten_object(res[0], drop_keys=True)
234+
for t in flat:
235+
if len(t.shape) == 4:
236+
self.assertIsInstance(t.shape[0], torch.SymInt)
237+
self.assertIsInstance(t.shape[2], torch.SymInt)
238+
self.assertEqual(t.shape[1], 32)
239+
self.assertEqual(t.shape[3], 96)
240+
self.assertNotEqual(t.shape[0], t.shape[2])
241+
187242

188243
if __name__ == "__main__":
189244
unittest.main(verbosity=2)

onnx_diagnostic/export/shape_helper.py

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Any, Dict, List, Set, Optional, Tuple, Union
2-
from ..helpers import flatten_object
32
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
4-
from ..helpers.fake_tensor_helper import make_fake
5-
from .dynamic_shapes import ModelInputs, _flatten_dynamic_shapes
3+
from ..helpers.fake_tensor_helper import fake_reshape
4+
from .dynamic_shapes import ModelInputs
65

76

87
def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
@@ -204,10 +203,10 @@ def guess_dynamic_shapes_from_inputs(
204203

205204

206205
def make_fake_with_dynamic_dimensions(
207-
inputs: Any,
206+
x: Any,
208207
dynamic_shapes: Any,
209208
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
210-
) -> Any: # noqa: F821
209+
) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
211210
"""
212211
Replaces all tensors by fake tensor respecting the same
213212
constraints as the following dynamic shapes.
@@ -235,19 +234,81 @@ def make_fake_with_dynamic_dimensions(
235234
),
236235
]
237236
),
238-
)
237+
),
238+
dynamic_shapes={
239+
"input_ids": {0: "batch", 1: "seq_length"},
240+
"attention_mask": {0: "batch", 1: "cache+seq"},
241+
"position_ids": {0: "batch", 1: "seq_length"},
242+
"past_key_values": [
243+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
244+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
245+
],
246+
},
239247
)
240248
print(inputs)
241249
"""
242-
flat_inputs = flatten_object(inputs, drop_keys=True)
243-
flat_fake, fake_mode = make_fake(flat_inputs, fake_mode=fake_mode)
244-
flat_ds = _flatten_dynamic_shapes(dynamic_shapes)
245-
assert len(flat_inputs) == len(flat_ds), (
246-
f"Mismatch between the number of input tensor {len(flat_inputs)} "
247-
f"and the number of dynamic_shapes {len(flat_ds)}"
250+
if x is None:
251+
return None, None
252+
if fake_mode is None:
253+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
254+
from torch._subclasses.fake_tensor import FakeTensorMode
255+
256+
shape_env = ShapeEnv()
257+
fake_mode = FakeTensorMode(shape_env=shape_env)
258+
259+
if isinstance(x, (list, tuple)):
260+
return (
261+
x.__class__(
262+
[
263+
make_fake_with_dynamic_dimensions(
264+
i, fake_mode=fake_mode, dynamic_shapes=ds
265+
)[0]
266+
for i, ds in zip(x, dynamic_shapes)
267+
]
268+
),
269+
fake_mode,
270+
)
271+
if isinstance(x, dict):
272+
return {
273+
k: make_fake_with_dynamic_dimensions(
274+
v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
275+
)[0]
276+
for k, v in x.items()
277+
}, fake_mode
278+
279+
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
280+
assert hasattr(x, "layers"), (
281+
f"Use a more recent version of transformers (>=4.55), "
282+
f"'layers' not found in class {type(x)}"
283+
)
284+
assert (
285+
isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
286+
), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
287+
for il, layer in enumerate(x.layers):
288+
assert hasattr(layer, "keys") and hasattr(layer, "values"), (
289+
f"Use a more recent version of transformers (>=4.55), 'keys'/'values' "
290+
f"not found in class {type(layer)} ({dir(layer)})"
291+
)
292+
layer.keys = make_fake_with_dynamic_dimensions(
293+
layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
294+
)[0]
295+
layer.values = make_fake_with_dynamic_dimensions(
296+
layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
297+
)[0]
298+
return x, fake_mode
299+
if x.__class__.__name__ == "EncoderDecoderCache":
300+
make_fake_with_dynamic_dimensions(
301+
x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
302+
)
303+
make_fake_with_dynamic_dimensions(
304+
x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
305+
)
306+
return x, fake_mode
307+
if hasattr(x, "shape"):
308+
t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
309+
return t, fake_mode
310+
from . import string_type
311+
312+
raise TypeError(
313+
f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
248314
)
249-
flat_reshaped = [
250-
make_fake_with_dynamic_dimensions(t, sh, true_tensor=t, fake_mode=fake_mode)
251-
for t, sh in zip(flat_inputs, flat_fake, flat_ds)
252-
]
253-
return flat_reshaped

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ disable_error_code = ["call-overload", "name-defined", "import-untyped"]
4646
module = ["onnx_diagnostic.ext_test_case"]
4747
disable_error_code = ["arg-type", "assignment", "import-untyped", "misc", "name-defined", "override", "return-value", "truthy-function"]
4848

49+
[[tool.mypy.overrides]]
50+
module = ["onnx_diagnostic.export.shape_helper"]
51+
disable_error_code = ["name-defined"]
52+
4953
[[tool.mypy.overrides]]
5054
module = ["onnx_diagnostic.helpers.args_helper"]
5155
disable_error_code = ["arg-type", "call-overload", "index"]

0 commit comments

Comments
 (0)