Skip to content

Commit 049c174

Browse files
committed
add command lines
1 parent 61d3480 commit 049c174

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def setUpClass(cls):
2323

2424
def test_run_aligned_record(self):
2525
r = RunAlignedRecord(
26-
ep_id_node=-1,
26+
ep_id_node=1,
2727
onnx_id_node=-1,
2828
ep_name="A",
2929
onnx_name="B",

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def forward(self, x):
112112
input_file = self.get_dump_file("test_h_parser_sbs.inputs.pt")
113113
ep_file = self.get_dump_file("test_h_parser_sbs.ep")
114114
onnx_file = self.get_dump_file("test_h_parser_sbs.model.onnx")
115+
replay_foler = self.get_dump_folder("test_h_parser_sbs.replay")
115116
torch.save(inputs, input_file)
116117
to_onnx(
117118
Model(),
@@ -139,6 +140,10 @@ def forward(self, x):
139140
output,
140141
"-m",
141142
onnx_file,
143+
"-t",
144+
"Gemm",
145+
"-f",
146+
replay_foler,
142147
]
143148
)
144149
text = st.getvalue()

onnx_diagnostic/_command_lines_parser.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,14 @@ def get_parser_sbs() -> ArgumentParser:
11401140
- torch.export.save(ep: torch.export.ExportedProgram)
11411141
- torch.save(**inputs)
11421142
- onnx.save(...)
1143+
1144+
The Replay functionality is just a way to investigate a part of a model.
1145+
It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
1146+
which shares its inputs with the exported program.
1147+
This is used to investigate the discrepancies between the torch
1148+
model (through the exported program) and its onnx conversion.
1149+
This functionality dumps everything it can to disk
1150+
so that it can be replayed in a separate process.
11431151
"""
11441152
),
11451153
)
@@ -1222,10 +1230,33 @@ def get_parser_sbs() -> ArgumentParser:
12221230
),
12231231
)
12241232
parser.add_argument(
1225-
"--gemmlinear",
1226-
action=BooleanOptionalAction,
1227-
default=False,
1228-
help="Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side",
1233+
"-s",
1234+
"--replay-threshold",
1235+
type=float,
1236+
required=False,
1237+
default=1e6,
1238+
help="Triggers the replay if the discrepancies are higher than this value.",
1239+
)
1240+
parser.add_argument(
1241+
"-n",
1242+
"--replay-names",
1243+
required=False,
1244+
default="",
1245+
help="Triggers the replay if a result name is in this set of values (comma separated)",
1246+
)
1247+
parser.add_argument(
1248+
"-t",
1249+
"--replay-op-types",
1250+
required=False,
1251+
default="",
1252+
help="Triggers the replay if an onnx type is in this set of values (comma separated)",
1253+
)
1254+
parser.add_argument(
1255+
"-f",
1256+
"--replay-folder",
1257+
required=False,
1258+
default="replay",
1259+
help="If the replay is triggered, this defines the folder where everything is dumped.",
12291260
)
12301261

12311262
return parser
@@ -1235,7 +1266,7 @@ def _cmd_sbs(argv: List[Any]):
12351266
import pandas
12361267
import torch
12371268
from .helpers import flatten_object, max_diff, string_diff, string_type
1238-
from .torch_onnx.sbs import run_aligned
1269+
from .torch_onnx.sbs import run_aligned, ReplayConfiguration
12391270
from .reference import OnnxruntimeEvaluator
12401271

12411272
parser = get_parser_sbs()
@@ -1306,6 +1337,17 @@ def _size(name):
13061337
onx = onnx.load(args.onnx)
13071338
print(f"-- done in {time.perf_counter() - begin:1.1f}s")
13081339

1340+
replay_configuration = None
1341+
if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
1342+
replay_configuration = ReplayConfiguration(
1343+
threshold=args.replay_threshold,
1344+
selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
1345+
selected_op_types=(
1346+
set(args.replay_op_types.split(",")) if args.replay_op_types else None
1347+
),
1348+
dump_folder=args.replay_folder,
1349+
)
1350+
13091351
print("-- starts side-by-side")
13101352
ratio = int(args.ratio)
13111353
data = []
@@ -1321,6 +1363,7 @@ def _size(name):
13211363
use_tensor=True,
13221364
reset_names=args.reset.split(","),
13231365
exc=False,
1366+
replay_configuration=replay_configuration,
13241367
):
13251368
data.append(obs)
13261369
if (

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,17 @@ class ReplayConfiguration:
184184
pieces to investigate
185185
:param selected_names: list of results names to dump
186186
:param selected_op_types: list of onnx operators to dump
187-
:param threshold: only keep thoses whose discrepancies is greater than that threshold
187+
:param threshold: only keep those whose discrepancy is greater than that threshold
188188
"""
189189

190190
dump_folder: str
191191
selected_names: Optional[Set[str]] = None
192192
selected_op_types: Optional[Set[str]] = None
193193
threshold: float = 0.1
194194

195+
def __post_init__(self):
196+
assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay"
197+
195198
def select(
196199
self,
197200
name: Optional[str] = None,

0 commit comments

Comments
 (0)