Skip to content

Commit 19fb4b7

Browse files
committed
stat
1 parent ebc2cd3 commit 19fb4b7

File tree

2 files changed

+157
-10
lines changed

2 files changed

+157
-10
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def forward(self, x, y):
2121
self.assertEqual(mi.name, "main")
2222
self.assertEqual(mi.true_model_name, "Model")
2323
self.assertEqual(mi.full_name, "main:Model")
24+
self.assertEqual(mi.module_name_type, "type(main)=Model")
2425

2526
def test_guess_dynamic_shapes_none(self):
2627
class Model(torch.nn.Module):
@@ -35,7 +36,7 @@ def forward(self, x, y):
3536

3637
mi = ModelInputs(Model(), [])
3738
ds = mi.guess_dynamic_shapes()
38-
self.assertEmpty(ds)
39+
self.assertEqual(ds, ((), {}))
3940

4041
def test_guess_dynamic_shapes_1args(self):
4142
class Model(torch.nn.Module):
@@ -85,6 +86,9 @@ def forward(self, x, y):
8586
{},
8687
),
8788
)
89+
self.assertEqual(
90+
(({}, {}), {}), ModelInputs(Model(), inputs[:1]).guess_dynamic_shapes()
91+
)
8892

8993
def test_guess_dynamic_shapes_kwargs(self):
9094
class Model(torch.nn.Module):
@@ -113,6 +117,9 @@ def forward(self, x=None, y=None):
113117
},
114118
),
115119
)
120+
self.assertEqual(
121+
((), {"x": {}, "y": {}}), ModelInputs(Model(), inputs[:1]).guess_dynamic_shapes()
122+
)
116123

117124
def test_guess_dynamic_shapes_args_kwargs(self):
118125
class Model(torch.nn.Module):
@@ -139,6 +146,9 @@ def forward(self, x, y=None):
139146
{"y": {1: torch.export.Dim.DYNAMIC}},
140147
),
141148
)
149+
self.assertEqual(
150+
(({},), {"y": {}}), ModelInputs(Model(), inputs[:1]).guess_dynamic_shapes()
151+
)
142152

143153
def test_guess_dynamic_shapes_kwargs_as_kwargs(self):
144154
class Model(torch.nn.Module):
@@ -162,6 +172,9 @@ def forward(self, **kwargs):
162172
self.assertEqual(ds, (tuple(), {"x": {0: torch.export.Dim.DYNAMIC}}))
163173
_a, _kw, ds = mi.move_to_kwargs(*mi.inputs[0], ds)
164174
self.assertEqual(ds, (tuple(), {"kwargs": {"x": {0: torch.export.Dim.DYNAMIC}}}))
175+
self.assertEqual(
176+
((), {"x": {}}), ModelInputs(Model(), inputs[:1]).guess_dynamic_shapes()
177+
)
165178

166179

167180
if __name__ == "__main__":

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,126 @@
66

77

88
class ModelInputs:
9-
""" """
9+
"""
10+
Wraps a model and a couple of sets of valid inputs.
11+
Based on that information, the class is able to infer the dynamic shapes
12+
for :func:`torch.export.export`.
13+
14+
:param model: model to export
15+
:param inputs: list of valid set of inputs
16+
:param level: if this module is a submodule, it is the level of submodule
17+
:param method_name: by default, the forward method is processed but it
18+
could be another one
19+
:param name: a name, mostly for debugging purposes
20+
21+
Examples:
22+
23+
**args**
24+
25+
.. runpython::
26+
:showcode:
27+
28+
import pprint
29+
import torch
30+
from onnx_diagnostic.export import ModelInputs
31+
32+
33+
class Model(torch.nn.Module):
34+
def forward(self, x, y):
35+
return x + y
36+
37+
38+
model = Model()
39+
x = torch.randn((5, 6))
40+
y = torch.randn((1, 6))
41+
model(x, y) # to check it works
42+
43+
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
44+
mi = ModelInputs(Model(), inputs)
45+
ds = mi.guess_dynamic_shapes()
46+
pprint.pprint(ds)
47+
48+
import pprint
49+
import torch
50+
from onnx_diagnostic.export import ModelInputs
51+
52+
**kwargs**
53+
54+
.. runpython::
55+
:showcode:
56+
57+
class Model(torch.nn.Module):
58+
def forward(self, x, y):
59+
return x + y
60+
61+
62+
model = Model()
63+
x = torch.randn((5, 6))
64+
y = torch.randn((1, 6))
65+
model(x=x, y=y) # to check it works
66+
67+
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
68+
mi = ModelInputs(Model(), inputs)
69+
ds = mi.guess_dynamic_shapes()
70+
pprint.pprint(ds)
71+
72+
import pprint
73+
import torch
74+
from onnx_diagnostic.export import ModelInputs
75+
76+
**and and kwargs**
77+
78+
.. runpython::
79+
:showcode:
80+
81+
class Model(torch.nn.Module):
82+
def forward(self, x, y):
83+
return x + y
84+
85+
86+
model = Model()
87+
x = torch.randn((5, 6))
88+
y = torch.randn((1, 6))
89+
model(x, y=y) # to check it works
90+
91+
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
92+
mi = ModelInputs(Model(), inputs)
93+
ds = mi.guess_dynamic_shapes()
94+
pprint.pprint(ds)
95+
96+
:func:`torch.export.export` does not like dynamic shapes defined both as args and kwargs.
97+
kwargs must be used. ``move_to_kwargs`` modifies the inputs and the dynamic shapes
98+
to make the model and the given inputs exportable.
99+
100+
.. runpython::
101+
:showcode:
102+
103+
import pprint
104+
import torch
105+
from onnx_diagnostic.export import ModelInputs
106+
from onnx_diagnostic.helpers import string_type
107+
108+
109+
class Model(torch.nn.Module):
110+
def forward(self, x, y):
111+
return x + y
112+
113+
114+
model = Model()
115+
x = torch.randn((5, 6))
116+
y = torch.randn((1, 6))
117+
model(x, y=y) # to check it works
118+
119+
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
120+
mi = ModelInputs(Model(), inputs)
121+
ds = mi.guess_dynamic_shapes()
122+
123+
a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
124+
print("moved args:", string_type(a, with_shape=True))
125+
print("moved kwargs:", string_type(kw, with_shape=True))
126+
print("dynamic shapes:")
127+
pprint.pprint(nds)
128+
"""
10129

11130
def __init__(
12131
self,
@@ -70,7 +189,10 @@ def process_inputs(
70189
List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
71190
],
72191
) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
73-
""" """
192+
"""
193+
Transforms a list of valid inputs, list of args, list of kwargs or list of both
194+
into a list of (args, kwargs).
195+
"""
74196
if not isinstance(inputs, list):
75197
raise ValueError(
76198
f"inputs should be specifed as a list of sets of "
@@ -111,7 +233,14 @@ def full_name(self):
111233
return f"{self.name}:{self.true_model_name}"
112234
return f"{self.name}:{self.true_model_name}.{self.method_name}"
113235

114-
def guess_dynamic_dimensions(self, *tensors) -> Any:
236+
@property
237+
def module_name_type(self):
238+
"Returns name and module type."
239+
if self.method_name == "forward":
240+
return f"type({self.name})={self.true_model_name}"
241+
return f"type({self.name})={self.true_model_name}.{self.method_name}"
242+
243+
def guess_dynamic_dimensions(self, *tensors) -> Dict[int, Any]:
115244
"""Infers the dynamic dimension from multiple shapes."""
116245
if len(tensors) == 1:
117246
return {}
@@ -122,7 +251,7 @@ def guess_dynamic_dimensions(self, *tensors) -> Any:
122251
f"shapes={shapes} for module {self.name!r}, "
123252
f"class={self.true_model_name!r}"
124253
)
125-
dynamic = torch.export.Dim.DYNAMIC
254+
dynamic: Any = torch.export.Dim.DYNAMIC # type: ignore
126255
rk = set_length.pop()
127256
res = {}
128257
for i in range(rk):
@@ -198,18 +327,20 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
198327
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
199328
)
200329

201-
def guess_dynamic_shapes(self) -> Any:
330+
def guess_dynamic_shapes(
331+
self,
332+
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
202333
"""
203334
Guesses the dynamic shapes for that module from two execution.
204335
If there is only one execution, then that would be static dimensions.
205336
"""
206337
if len(self.inputs) == 0:
207338
# No inputs, unable to guess.
208-
return None
339+
return (tuple(), {})
209340
if len(self.inputs) == 1:
210341
# No dynamic shapes.
211342
return tuple(self.guess_dynamic_shape_object(a) for a in self.inputs[0][0]), {
212-
k: self.guess_dynamic_shape_object(v) for k, v in self.inputs[0][1]
343+
k: self.guess_dynamic_shape_object(v) for k, v in self.inputs[0][1].items()
213344
}
214345

215346
# Otherwise.
@@ -241,8 +372,11 @@ def guess_dynamic_shapes(self) -> Any:
241372
return tuple(args), kwargs
242373

243374
def move_to_kwargs(
244-
self, args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: Any
245-
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
375+
self,
376+
args: Tuple[Any, ...],
377+
kwargs: Dict[str, Any],
378+
dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
379+
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
246380
"""
247381
Uses the signatures to move positional arguments (args) to named arguments (kwargs)
248382
with the corresponding dynamic shapes.

0 commit comments

Comments
 (0)