Skip to content

Commit 1d3eddb

Browse files
committed
fix shapes
1 parent f483e37 commit 1d3eddb

File tree

7 files changed

+828
-717
lines changed

7 files changed

+828
-717
lines changed

_doc/api/export/dynamic_shapes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ onnx_diagnostic.export.dynamic_shapes
55
.. automodule:: onnx_diagnostic.export.dynamic_shapes
66
:members:
77
:no-undoc-members:
8-
:exclude-members: onnx_diagnostic.export.dynamic_shapes
8+
:exclude-members: CoupleInputsDynamicShapes, ModelInputs

_doc/api/export/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ onnx_diagnostic.export
77

88
dynamic_shapes
99

10+
CoupleInputsDynamicShapes
11+
+++++++++++++++++++++++++
12+
13+
.. autoclass:: onnx_diagnostic.export.CoupleInputsDynamicShapes
14+
:members:
15+
1016
ModelInputs
1117
+++++++++++
1218

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from onnx_diagnostic.ext_test_case import ExtTestCase
44
from onnx_diagnostic.helpers import string_type
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
6-
from onnx_diagnostic.export import ModelInputs
7-
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
6+
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
87
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
98

109

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase
4+
from onnx_diagnostic.helpers import string_type
5+
from onnx_diagnostic.helpers.cache_helper import (
6+
make_dynamic_cache,
7+
flatten_unflatten_for_dynamic_shapes,
8+
)
9+
from onnx_diagnostic.export import ModelInputs
10+
11+
12+
class TestSerialization(ExtTestCase):
13+
def _get_cache(self, n_layers=2, bsize=2, nheads=4, slen=1, dim=7):
14+
return make_dynamic_cache(
15+
[
16+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
17+
for i in range(n_layers)
18+
]
19+
)
20+
21+
def test_dynamic_cache(self):
22+
class Model(torch.nn.Module):
23+
def forward(self, cache):
24+
return cache.key_cache[0]
25+
26+
cache = self._get_cache()
27+
flat_unflat = flatten_unflatten_for_dynamic_shapes(cache)
28+
s = string_type(flat_unflat, with_shape=True)
29+
self.assertEqual(s, "#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]")
30+
DYN = torch.export.Dim.DYNAMIC
31+
ds = {0: DYN, 1: DYN, 3: DYN}
32+
dynamic_shapes = ([[ds, ds], [ds, ds]],)
33+
exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes)
34+
self.assertNotEmpty(exp)
35+
36+
def test_dynamic_cache_guess_static(self):
37+
class Model(torch.nn.Module):
38+
def forward(self, cache):
39+
return cache.key_cache[0]
40+
41+
cache = self._get_cache()
42+
md = ModelInputs(Model(), [(cache,)])
43+
guessed = md.guess_dynamic_shapes()
44+
self.assertEqual(guessed, (([[{}, {}], [{}, {}]],), {}))
45+
46+
def test_dynamic_cache_guess_dynamic(self):
47+
class Model(torch.nn.Module):
48+
def forward(self, cache):
49+
return cache.key_cache[0]
50+
51+
md = ModelInputs(
52+
Model(), [(self._get_cache(),), (self._get_cache(bsize=3, nheads=5),)]
53+
)
54+
guessed = md.guess_dynamic_shapes()
55+
DYN = torch.export.Dim.DYNAMIC
56+
self.assertEqual(
57+
guessed,
58+
(
59+
(
60+
[
61+
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
62+
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
63+
],
64+
),
65+
{},
66+
),
67+
)
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)