@@ -1202,13 +1202,26 @@ def get_parser_sbs() -> ArgumentParser:
12021202 required = False ,
12031203 help = "Saves the result in an excel file every <ratio> nodes." ,
12041204 )
1205+ parser .add_argument (
1206+ "--first" ,
1207+ action = BooleanOptionalAction ,
1208+ default = False ,
1209+ help = "First runs the whole model." ,
1210+ )
1211+ parser .add_argument (
1212+ "--gemmlinear" ,
1213+ action = BooleanOptionalAction ,
1214+ default = False ,
1215+ help = "Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side" ,
1216+ )
1217+
12051218 return parser
12061219
12071220
12081221def _cmd_sbs (argv : List [Any ]):
12091222 import pandas
12101223 import torch
1211- from .helpers import string_type
1224+ from .helpers import flatten_object , max_diff , string_diff , string_type
12121225 from .torch_onnx .sbs import run_aligned
12131226 from .reference import OnnxruntimeEvaluator
12141227
@@ -1257,6 +1270,23 @@ def _size(name):
12571270 ep = torch .export .load (args .ep )
12581271 print (f"-- done in { time .perf_counter () - begin :1.1f} s" )
12591272
1273+ if args .first :
1274+ print ("-- compare first, run ep" )
1275+ print (f"-- args: { string_type (margs , with_shape = True , with_device = True )} " )
1276+ print (f"-- mkwargs: { string_type (mkwargs , with_shape = True , with_device = True )} " )
1277+ expected = ep .module ()(* margs , ** mkwargs )
1278+ print (f"-- expected: { string_type (expected , with_shape = True , with_device = True )} " )
1279+ sess = OnnxruntimeEvaluator (args .onnx , whole = True )
1280+ onx_inputs = flatten_object ([margs , mkwargs ], drop_keys = True )
1281+ feeds = dict (zip (sess .input_names , onx_inputs ))
1282+ print (f"-- feeds: { string_type (feeds , with_shape = True , with_device = True )} " )
1283+ got = sess .run (None , feeds )
1284+ print (f"-- got: { string_type (got , with_shape = True , with_device = True )} " )
1285+ diff = max_diff (expected , got , hist = [0.1 ])
1286+ print (f"-- diff: { string_diff (diff )} " )
1287+ print ("-- done" )
1288+ del sess
1289+
12601290 print (f"-- load onnx { args .onnx !r} " )
12611291 begin = time .perf_counter ()
12621292 onx = onnx .load (args .onnx )
@@ -1275,6 +1305,7 @@ def _size(name):
12751305 args = margs ,
12761306 kwargs = mkwargs ,
12771307 use_tensor = True ,
1308+ gemmlinear = args .gemmlinear ,
12781309 exc = False ,
12791310 ):
12801311 data .append (obs )
0 commit comments