11import unittest
2+ import onnx
23from onnx_diagnostic .ext_test_case import (
34 ExtTestCase ,
45 hide_stdout ,
910from onnx_diagnostic .reference import ExtendedReferenceEvaluator , OnnxruntimeEvaluator
1011from onnx_diagnostic .torch_export_patches .patch_inputs import use_dyn_not_str
1112from onnx_diagnostic .torch_onnx .sbs import run_aligned
12-
13- try :
14- from experimental_experiment .torch_interpreter import to_onnx
15- except ImportError :
16- to_onnx = None
13+ from onnx_diagnostic .export .api import to_onnx
1714
1815
1916class TestSideBySide (ExtTestCase ):
@@ -41,7 +38,7 @@ def forward(self, x):
4138 ep = self .torch .export .export (
4239 Model (), (x ,), dynamic_shapes = ({0 : self .torch .export .Dim ("batch" )},)
4340 )
44- onx = to_onnx (ep )
41+ onx = to_onnx (ep , exporter = "custom" ). model_proto
4542 results = list (
4643 run_aligned (
4744 ep ,
@@ -71,10 +68,12 @@ def forward(self, x):
7168 ep = self .torch .export .export (
7269 Model (), (x ,), dynamic_shapes = ({0 : self .torch .export .Dim ("batch" )},)
7370 )
74- epo = self .torch .onnx .export (
75- ep , (x ,), dynamic_shapes = ({0 : self .torch .export .Dim ("batch" )},), dynamo = True
76- )
77- onx = epo .model_proto
71+ onx = to_onnx (
72+ ep ,
73+ (x ,),
74+ dynamic_shapes = ({0 : self .torch .export .Dim ("batch" )},),
75+ exporter = "onnx-dynamo" ,
76+ ).model_proto
7877 results = list (
7978 run_aligned (
8079 ep ,
@@ -105,9 +104,7 @@ def forward(self, x):
105104 ep = self .torch .export .export (
106105 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
107106 )
108- epo = self .torch .onnx .export (
109- Model (), (), kwargs = inputs , dynamic_shapes = ds , dynamo = True
110- )
107+ epo = to_onnx (Model (), (), kwargs = inputs , dynamic_shapes = ds , exporter = "onnx-dynamo" )
111108 onx = epo .model_proto
112109 results = list (
113110 run_aligned (
@@ -139,7 +136,7 @@ def forward(self, x):
139136 ep = self .torch .export .export (
140137 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
141138 )
142- onx = to_onnx (ep )
139+ onx = to_onnx (ep , exporter = "custom" ). model_proto
143140 results = list (
144141 run_aligned (
145142 ep ,
@@ -170,7 +167,7 @@ def forward(self, x):
170167 ep = self .torch .export .export (
171168 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
172169 )
173- onx = to_onnx (ep )
170+ onx = to_onnx (ep , exporter = "custom" ). model_proto
174171 results = list (
175172 run_aligned (
176173 ep ,
@@ -204,7 +201,7 @@ def forward(self, x):
204201 ep = self .torch .export .export (
205202 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
206203 )
207- onx = to_onnx (ep )
204+ onx = to_onnx (ep , exporter = "custom" ). model_proto
208205 results = list (
209206 run_aligned (
210207 ep ,
@@ -240,7 +237,7 @@ def forward(self, x):
240237 ep = self .torch .export .export (
241238 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
242239 )
243- onx = to_onnx (ep )
240+ onx = to_onnx (ep , exporter = "custom" ). model_proto
244241 results = list (
245242 run_aligned (
246243 ep ,
@@ -275,7 +272,7 @@ def forward(self, x):
275272 ep = self .torch .export .export (
276273 Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
277274 )
278- onx = to_onnx (ep )
275+ onx = to_onnx (ep , exporter = "custom" ). model_proto
279276 results = list (
280277 run_aligned (
281278 ep ,
@@ -291,6 +288,45 @@ def forward(self, x):
291288 self .assertEqual (len (results ), 7 )
292289 self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 , 0 , 0 , 0 , 0 , 0 , 0 ])
293290
291+ @hide_stdout ()
292+ @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
293+ def test_sbs_model_with_weights (self ):
294+ torch = self .torch
295+
296+ class Model (self .torch .nn .Module ):
297+ def __init__ (self ):
298+ super (Model , self ).__init__ ()
299+ self .fc1 = torch .nn .Linear (10 , 32 ) # input size 10 → hidden size 32
300+ self .relu = torch .nn .ReLU ()
301+ self .fc2 = torch .nn .Linear (32 , 1 ) # hidden → output
302+
303+ def forward (self , x ):
304+ x = self .relu (self .fc1 (x ))
305+ x = self .fc2 (x )
306+ return x
307+
308+ inputs = dict (x = self .torch .randn ((5 , 10 )))
309+ ds = dict (x = {0 : "batch" })
310+ Model ()(** inputs )
311+ ep = self .torch .export .export (
312+ Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
313+ )
314+ filename = self .get_dump_file ("test_sbs_model_with_weights.onnx" )
315+ to_onnx (ep , exporter = "custom" , filename = filename )
316+ onx = onnx .load (filename )
317+ results = list (
318+ run_aligned (
319+ ep ,
320+ onx ,
321+ kwargs = inputs ,
322+ run_cls = OnnxruntimeEvaluator ,
323+ verbose = 11 ,
324+ use_tensor = True ,
325+ ),
326+ )
327+ self .assertEqual (len (results ), 7 )
328+ self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 , 0 , 0 , 0 , 0 , 0 , 0 ])
329+
294330
295331if __name__ == "__main__" :
296332 unittest .main (verbosity = 2 )
0 commit comments