66
77
88class ModelInputs :
9- """ """
9+ """
10+ Wraps a model and a couple of sets of valid inputs.
11+ Based on that information, the class is able to infer the dynamic shapes
12+ for :func:`torch.export.export`.
13+
14+ :param model: model to export
15+ :param inputs: list of valid set of inputs
16+ :param level: if this module is a submodule, it is the level of submodule
17+ :param method_name: by default, the forward method is processed but it
18+ could be another one
19+ :param name: a name, mostly for debugging purposes
20+
21+ Examples:
22+
23+ **args**
24+
25+ .. runpython::
26+ :showcode:
27+
28+ import pprint
29+ import torch
30+ from onnx_diagnostic.export import ModelInputs
31+
32+
33+ class Model(torch.nn.Module):
34+ def forward(self, x, y):
35+ return x + y
36+
37+
38+ model = Model()
39+ x = torch.randn((5, 6))
40+ y = torch.randn((1, 6))
41+ model(x, y) # to check it works
42+
43+ inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
44+ mi = ModelInputs(Model(), inputs)
45+ ds = mi.guess_dynamic_shapes()
46+ pprint.pprint(ds)
47+
48+ import pprint
49+ import torch
50+ from onnx_diagnostic.export import ModelInputs
51+
52+ **kwargs**
53+
54+ .. runpython::
55+ :showcode:
56+
57+ class Model(torch.nn.Module):
58+ def forward(self, x, y):
59+ return x + y
60+
61+
62+ model = Model()
63+ x = torch.randn((5, 6))
64+ y = torch.randn((1, 6))
65+ model(x=x, y=y) # to check it works
66+
67+ inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
68+ mi = ModelInputs(Model(), inputs)
69+ ds = mi.guess_dynamic_shapes()
70+ pprint.pprint(ds)
71+
72+ import pprint
73+ import torch
74+ from onnx_diagnostic.export import ModelInputs
75+
76+ **and and kwargs**
77+
78+ .. runpython::
79+ :showcode:
80+
81+ class Model(torch.nn.Module):
82+ def forward(self, x, y):
83+ return x + y
84+
85+
86+ model = Model()
87+ x = torch.randn((5, 6))
88+ y = torch.randn((1, 6))
89+ model(x, y=y) # to check it works
90+
91+ inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
92+ mi = ModelInputs(Model(), inputs)
93+ ds = mi.guess_dynamic_shapes()
94+ pprint.pprint(ds)
95+
96+ :func:`torch.export.export` does not like dynamic shapes defined both as args and kwargs.
97+ kwargs must be used. ``move_to_kwargs`` modifies the inputs and the dynamic shapes
98+ to make the model and the given inputs exportable.
99+
100+ .. runpython::
101+ :showcode:
102+
103+ import pprint
104+ import torch
105+ from onnx_diagnostic.export import ModelInputs
106+ from onnx_diagnostic.helpers import string_type
107+
108+
109+ class Model(torch.nn.Module):
110+ def forward(self, x, y):
111+ return x + y
112+
113+
114+ model = Model()
115+ x = torch.randn((5, 6))
116+ y = torch.randn((1, 6))
117+ model(x, y=y) # to check it works
118+
119+ inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
120+ mi = ModelInputs(Model(), inputs)
121+ ds = mi.guess_dynamic_shapes()
122+
123+ a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
124+ print("moved args:", string_type(a, with_shape=True))
125+ print("moved kwargs:", string_type(kw, with_shape=True))
126+ print("dynamic shapes:")
127+ pprint.pprint(nds)
128+ """
10129
11130 def __init__ (
12131 self ,
@@ -70,7 +189,10 @@ def process_inputs(
70189 List [Tuple [Tuple [Any , ...], Dict [str , Any ]]],
71190 ],
72191 ) -> List [Tuple [Tuple [Any , ...], Dict [str , Any ]]]:
73- """ """
192+ """
193+ Transforms a list of valid inputs, list of args, list of kwargs or list of both
194+ into a list of (args, kwargs).
195+ """
74196 if not isinstance (inputs , list ):
75197 raise ValueError (
76198 f"inputs should be specifed as a list of sets of "
@@ -111,7 +233,14 @@ def full_name(self):
111233 return f"{ self .name } :{ self .true_model_name } "
112234 return f"{ self .name } :{ self .true_model_name } .{ self .method_name } "
113235
114- def guess_dynamic_dimensions (self , * tensors ) -> Any :
236+ @property
237+ def module_name_type (self ):
238+ "Returns name and module type."
239+ if self .method_name == "forward" :
240+ return f"type({ self .name } )={ self .true_model_name } "
241+ return f"type({ self .name } )={ self .true_model_name } .{ self .method_name } "
242+
243+ def guess_dynamic_dimensions (self , * tensors ) -> Dict [int , Any ]:
115244 """Infers the dynamic dimension from multiple shapes."""
116245 if len (tensors ) == 1 :
117246 return {}
@@ -122,7 +251,7 @@ def guess_dynamic_dimensions(self, *tensors) -> Any:
122251 f"shapes={ shapes } for module { self .name !r} , "
123252 f"class={ self .true_model_name !r} "
124253 )
125- dynamic = torch .export .Dim .DYNAMIC
254+ dynamic : Any = torch .export .Dim .DYNAMIC # type: ignore
126255 rk = set_length .pop ()
127256 res = {}
128257 for i in range (rk ):
@@ -198,18 +327,20 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
198327 f"{ string_type (objs )} { msg () if msg else '' } in { self .module_name_type } "
199328 )
200329
201- def guess_dynamic_shapes (self ) -> Any :
330+ def guess_dynamic_shapes (
331+ self ,
332+ ) -> Tuple [Tuple [Any , ...], Dict [str , Any ]]:
202333 """
203334 Guesses the dynamic shapes for that module from two execution.
204335 If there is only one execution, then that would be static dimensions.
205336 """
206337 if len (self .inputs ) == 0 :
207338 # No inputs, unable to guess.
208- return None
339+ return ( tuple (), {})
209340 if len (self .inputs ) == 1 :
210341 # No dynamic shapes.
211342 return tuple (self .guess_dynamic_shape_object (a ) for a in self .inputs [0 ][0 ]), {
212- k : self .guess_dynamic_shape_object (v ) for k , v in self .inputs [0 ][1 ]
343+ k : self .guess_dynamic_shape_object (v ) for k , v in self .inputs [0 ][1 ]. items ()
213344 }
214345
215346 # Otherwise.
@@ -241,8 +372,11 @@ def guess_dynamic_shapes(self) -> Any:
241372 return tuple (args ), kwargs
242373
243374 def move_to_kwargs (
244- self , args : Tuple [Any , ...], kwargs : Dict [str , Any ], dynamic_shapes : Any
245- ) -> Tuple [Tuple [Any , ...], Dict [str , Any ], Dict [str , Any ]]:
375+ self ,
376+ args : Tuple [Any , ...],
377+ kwargs : Dict [str , Any ],
378+ dynamic_shapes : Tuple [Tuple [Any , ...], Dict [str , Any ]],
379+ ) -> Tuple [Tuple [Any , ...], Dict [str , Any ], Tuple [Tuple [Any , ...], Dict [str , Any ]]]:
246380 """
247381 Uses the signatures to move positional arguments (args) to named arguments (kwargs)
248382 with the corresponding dynamic shapes.
0 commit comments