Skip to content

Commit f6c684e

Browse files
committed
last changes
1 parent 9e47e33 commit f6c684e

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.8.3
55
+++++
66

7-
* :pr:`304`: improves side-by-side comparison
7+
* :pr:`304`, :pr:`306`: improves side-by-side comparison
88

99
0.8.2
1010
+++++

onnx_diagnostic/_command_lines_parser.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,10 +1114,20 @@ def get_parser_sbs() -> ArgumentParser:
11141114
the exported onnx model. It assumes some names are common.
11151115
The execution of the exported program and the onnx model
11161116
are done in parallel. The device is the one used to store the
1117-
model and the inputs.s
1117+
model and the inputs.
1118+
Where do discrepancies start? This function tries to answer that question.
1119+
"""
1120+
),
1121+
epilog=textwrap.dedent(
1122+
"""
1123+
The command line expects the following files to be saved with
1124+
the following function. inputs is a dictionary of the input of the model.
1125+
1126+
- torch.export.save(ep: torch.export.ExportedProgram)
1127+
- torch.save(**inputs)
1128+
- onnx.save(...)
11181129
"""
11191130
),
1120-
epilog="Where do discrepancies start? This function tries to answer that question.",
11211131
)
11221132
parser.add_argument(
11231133
"-i",
@@ -1244,7 +1254,7 @@ def _size(name):
12441254
data.append(obs)
12451255
if (
12461256
obs.onnx_op_type != "initializer"
1247-
and onnx.ep_target != "placeholder"
1257+
and obs.ep_target != "placeholder"
12481258
and len(data) % ratio == 0
12491259
):
12501260
df = pandas.DataFrame(data).apply(

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,11 @@ def _loop_cmp(
674674
f"[run_aligned] run onx.graph.node[{i_onnx}]: "
675675
f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
676676
)
677+
elif verbose == 1:
678+
loop.set_description(
679+
f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
680+
f"mapped {yielded_nodes} maxabs {max_abs:1.5f}"
681+
)
677682
ref = run_cls(node, **run_cls_kwargs)
678683
feeds = {k: onnx_results[k] for k in node.input}
679684
res = ref.run(None, feeds) # type: ignore[attr-defined]
@@ -700,7 +705,8 @@ def _loop_cmp(
700705
f"res={string_type(res, with_device=True, with_shape=True)}, "
701706
f"node is {pretty_onnx(node)}"
702707
)
703-
for o, r in zip(node.output, res):
708+
node_output = [o for o in node.output if o]
709+
for o, r in zip(node_output, res):
704710
tmp = _loop_cmp(
705711
mapping_onnx_to_torch,
706712
torch_results,
@@ -739,7 +745,8 @@ def _loop_cmp(
739745
ref = run_cls(node, **run_cls_kwargs)
740746
feeds = {k: onnx_results[k] for k in node.input}
741747
res = ref.run(None, feeds) # type: ignore[attr-defined]
742-
for o, r in zip(node.output, res):
748+
node_output = [o for o in node.output if o]
749+
for o, r in zip(node_output, res):
743750
tmp = _loop_cmp(
744751
mapping_onnx_to_torch,
745752
torch_results,

0 commit comments

Comments
 (0)