@@ -1140,6 +1140,14 @@ def get_parser_sbs() -> ArgumentParser:
11401140 - torch.export.save(ep: torch.export.ExportedProgram)
11411141 - torch.save(**inputs)
11421142 - onnx.save(...)
1143+
1144+ The Replay functionality is just a way to investigate a part of a model.
1145+ It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
1146+ which shares its inputs with the exported program.
1147+ This is used to investigate the discrepancies between the torch
1148+ model (through the exported program) and its onnx conversion.
1149+ This functionality dumps everything it can to disk
1150+ so that it can be replayed in a separate process.
11431151 """
11441152 ),
11451153 )
@@ -1222,10 +1230,33 @@ def get_parser_sbs() -> ArgumentParser:
12221230 ),
12231231 )
12241232 parser .add_argument (
1225- "--gemmlinear" ,
1226- action = BooleanOptionalAction ,
1227- default = False ,
1228- help = "Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side" ,
1233+ "-s" ,
1234+ "--replay-threshold" ,
1235+ type = float ,
1236+ required = False ,
1237+ default = 1e6 ,
1238+ help = "Triggers the replay if the discrepancies are higher than this value." ,
1239+ )
1240+ parser .add_argument (
1241+ "-n" ,
1242+ "--replay-names" ,
1243+ required = False ,
1244+ default = "" ,
1245+ help = "Triggers the replay if a result name is in this set of values (comma separated)" ,
1246+ )
1247+ parser .add_argument (
1248+ "-t" ,
1249+ "--replay-op-types" ,
1250+ required = False ,
1251+ default = "" ,
1252+ help = "Triggers the replay if an onnx type is in this set of values (comma separated)" ,
1253+ )
1254+ parser .add_argument (
1255+ "-f" ,
1256+ "--replay-folder" ,
1257+ required = False ,
1258+ default = "replay" ,
1259+ help = "If the replay is triggered, this defines the folder where everything is dumped." ,
12291260 )
12301261
12311262 return parser
@@ -1235,7 +1266,7 @@ def _cmd_sbs(argv: List[Any]):
12351266 import pandas
12361267 import torch
12371268 from .helpers import flatten_object , max_diff , string_diff , string_type
1238- from .torch_onnx .sbs import run_aligned
1269+ from .torch_onnx .sbs import run_aligned , ReplayConfiguration
12391270 from .reference import OnnxruntimeEvaluator
12401271
12411272 parser = get_parser_sbs ()
@@ -1306,6 +1337,17 @@ def _size(name):
13061337 onx = onnx .load (args .onnx )
13071338 print (f"-- done in { time .perf_counter () - begin :1.1f} s" )
13081339
1340+ replay_configuration = None
1341+ if args .replay_threshold < 1e6 or args .replay_names or args .replay_op_types :
1342+ replay_configuration = ReplayConfiguration (
1343+ threshold = args .replay_threshold ,
1344+ selected_names = set (args .replay_names .split ("," )) if args .replay_names else None ,
1345+ selected_op_types = (
1346+ set (args .replay_op_types .split ("," )) if args .replay_op_types else None
1347+ ),
1348+ dump_folder = args .replay_folder ,
1349+ )
1350+
13091351 print ("-- starts side-by-side" )
13101352 ratio = int (args .ratio )
13111353 data = []
@@ -1321,6 +1363,7 @@ def _size(name):
13211363 use_tensor = True ,
13221364 reset_names = args .reset .split ("," ),
13231365 exc = False ,
1366+ replay_configuration = replay_configuration ,
13241367 ):
13251368 data .append (obs )
13261369 if (
0 commit comments