@@ -1199,6 +1199,7 @@ def assert_onnx_disc(
11991199 expected : Optional [Any ] = None ,
12001200 use_ort : bool = False ,
12011201 ort_optimized_graph : bool = False ,
1202+ ep : Optional [Union ["torch.export.ExportedProgram" , str ]] = None , # noqa: F821
12021203 ** kwargs ,
12031204 ):
12041205 """
@@ -1218,6 +1219,7 @@ def assert_onnx_disc(
12181219 :param copy_inputs: to copy the inputs
12191220 :param use_ort: use :class:`onnxruntime.InferenceSession`
12201221 :param ort_optimized_graph: dumps the optimized onnxruntime graph
1222+ :param ep: exported program (or saved exported program)
12211223 :param kwargs: arguments sent to
12221224 :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
12231225 """
@@ -1245,6 +1247,7 @@ def assert_onnx_disc(
12451247 print (f"[{ vname } ] file size { os .stat (name ).st_size // 2 ** 10 :1.3f} kb" )
12461248 if verbose :
12471249 print (f"[{ vname } ] make feeds { string_type (inputs , ** kws )} " )
1250+
12481251 if use_ort :
12491252 assert isinstance (
12501253 proto , onnx .ModelProto
@@ -1275,6 +1278,7 @@ def assert_onnx_disc(
12751278 got = sess .run (None , feeds )
12761279 if verbose :
12771280 print (f"[{ vname } ] compute expected values" )
1281+
12781282 if expected is None :
12791283 if copy_inputs :
12801284 expected = (
@@ -1284,9 +1288,45 @@ def assert_onnx_disc(
12841288 )
12851289 else :
12861290 expected = model (* inputs ) if isinstance (inputs , tuple ) else model (** inputs )
1291+
12871292 if verbose :
12881293 print (f"[{ vname } ] expected { string_type (expected , ** kws )} " )
12891294 print (f"[{ vname } ] obtained { string_type (got , ** kws )} " )
1295+
1296+ if ep :
1297+ if isinstance (ep , str ):
1298+ if verbose :
1299+ print (f"[{ vname } ] load exported program { ep !r} " )
1300+ import torch
1301+
1302+ ep = torch .export .load (ep )
1303+ ep_inputs = copy .deepcopy (inputs ) if copy_inputs else inputs
1304+ ep_model = ep .module ()
1305+ ep_expected = (
1306+ ep_model (* copy .deepcopy (ep_inputs ))
1307+ if isinstance (ep_inputs , tuple )
1308+ else ep_model (** copy .deepcopy (ep_inputs ))
1309+ )
1310+ if verbose :
1311+ print (f"[{ vname } ] ep_expected { string_type (ep_expected , ** kws )} " )
1312+ ep_diff = max_diff (expected , ep_expected )
1313+ if verbose :
1314+ print (f"[{ vname } ] ep_diff { string_diff (ep_diff )} " )
1315+ assert (
1316+ isinstance (ep_diff ["abs" ], float )
1317+ and isinstance (ep_diff ["rel" ], float )
1318+ and not numpy .isnan (ep_diff ["abs" ])
1319+ and ep_diff ["abs" ] <= atol
1320+ and not numpy .isnan (ep_diff ["rel" ])
1321+ and ep_diff ["rel" ] <= rtol
1322+ ), (
1323+ f"discrepancies in { test_name !r} between the model "
1324+ f"and the exported model diff={ string_diff (ep_diff )} "
1325+ )
1326+ ep_nx_diff = max_diff (ep_expected , got , flatten = True )
1327+ if verbose :
1328+ print (f"[{ vname } ] ep_nx_diff { string_diff (ep_nx_diff )} " )
1329+
12901330 diff = max_diff (expected , got , flatten = True )
12911331 if verbose :
12921332 print (f"[{ vname } ] diff { string_diff (diff )} " )
@@ -1297,7 +1337,10 @@ def assert_onnx_disc(
12971337 and diff ["abs" ] <= atol
12981338 and not numpy .isnan (diff ["rel" ])
12991339 and diff ["rel" ] <= rtol
1300- ), f"discrepancies in { test_name !r} , diff={ string_diff (diff )} "
1340+ ), (
1341+ f"discrepancies in { test_name !r} between the model and "
1342+ f"the onnx model diff={ string_diff (diff )} "
1343+ )
13011344
13021345 def _debug (self ):
13031346 "Tells if DEBUG=1 is set up."
@@ -1308,6 +1351,11 @@ def string_type(self, *args, **kwargs):
13081351
13091352 return string_type (* args , ** kwargs )
13101353
1354+ def max_diff (self , * args , ** kwargs ):
1355+ from .helpers import max_diff
1356+
1357+ return max_diff (* args , ** kwargs )
1358+
13111359 def subloop (self , * args , verbose : int = 0 ):
13121360 "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
13131361 if len (args ) == 1 :
0 commit comments