Skip to content

Commit 1eeedc4

Browse files
committed
Better side-by-side
1 parent e8dcd00 commit 1eeedc4

File tree

3 files changed

+187
-168
lines changed

3 files changed

+187
-168
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
1212
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
13-
from onnx_diagnostic.torch_onnx.sbs import run_aligned, post_process_run_aligned_obs
13+
from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord
1414
from onnx_diagnostic.export.api import to_onnx
1515

1616

@@ -21,6 +21,24 @@ def setUpClass(cls):
2121

2222
cls.torch = torch
2323

24+
def test_run_aligned_record(self):
25+
r = RunAlignedRecord(
26+
ep_id_node=-1,
27+
onnx_id_node=-1,
28+
ep_name="A",
29+
onnx_name="B",
30+
ep_target="C",
31+
onnx_op_type="D",
32+
shape_type="E",
33+
err_abs=0.1,
34+
err_rel=0.2,
35+
err_dev=0.3,
36+
err_nan=0.4,
37+
)
38+
sr = str(r)
39+
self.assertIn("RunAlignedRecord(", sr)
40+
self.assertIn("shape_type='E'", sr)
41+
2442
@hide_stdout()
2543
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
2644
@ignore_errors(OSError) # connectivity issues
@@ -48,7 +66,7 @@ def forward(self, x):
4866
run_cls=ExtendedReferenceEvaluator,
4967
atol=1e-5,
5068
rtol=1e-5,
51-
verbose=1,
69+
verbose=10,
5270
),
5371
)
5472
self.assertEqual(len(results), 7)
@@ -83,7 +101,7 @@ def forward(self, x):
83101
run_cls=ExtendedReferenceEvaluator,
84102
atol=1e-5,
85103
rtol=1e-5,
86-
verbose=1,
104+
verbose=10,
87105
),
88106
)
89107
self.assertEqual(len(results), 6)
@@ -115,7 +133,7 @@ def forward(self, x):
115133
run_cls=ExtendedReferenceEvaluator,
116134
atol=1e-5,
117135
rtol=1e-5,
118-
verbose=1,
136+
verbose=10,
119137
),
120138
)
121139
self.assertEqual(len(results), 6)
@@ -285,7 +303,10 @@ def forward(self, x):
285303
),
286304
)
287305
self.assertEqual(len(results), 14)
288-
self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 14)
306+
self.assertEqual(
307+
[r.err_dev for r in results],
308+
[None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0, 0],
309+
)
289310

290311
@hide_stdout()
291312
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -323,7 +344,7 @@ def forward(self, x):
323344
use_tensor=True,
324345
),
325346
)
326-
df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results)))
347+
df = pandas.DataFrame(list(results))
327348
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom.xlsx"))
328349
self.assertEqual(
329350
[
@@ -332,6 +353,7 @@ def forward(self, x):
332353
"ep_target",
333354
"err_abs",
334355
"err_dev",
356+
"err_nan",
335357
"err_rel",
336358
"onnx_id_node",
337359
"onnx_name",
@@ -341,7 +363,10 @@ def forward(self, x):
341363
sorted(df.columns),
342364
)
343365
self.assertEqual(len(results), 12)
344-
self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12)
366+
self.assertEqual(
367+
[r.err_dev for r in results],
368+
[None, None, None, None, None, None, None, None, None, 0, 0, 0],
369+
)
345370
self.assertEqual(
346371
[-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
347372
df["onnx_id_node"].fillna(-10).tolist(),
@@ -384,7 +409,7 @@ def forward(self, x):
384409
use_tensor=True,
385410
),
386411
)
387-
df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results)))
412+
df = pandas.DataFrame(list(results))
388413
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_dynamo.xlsx"))
389414
self.assertEqual(
390415
[
@@ -393,6 +418,7 @@ def forward(self, x):
393418
"ep_target",
394419
"err_abs",
395420
"err_dev",
421+
"err_nan",
396422
"err_rel",
397423
"onnx_id_node",
398424
"onnx_name",
@@ -402,7 +428,10 @@ def forward(self, x):
402428
sorted(df.columns),
403429
)
404430
self.assertEqual(len(results), 12)
405-
self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12)
431+
self.assertEqual(
432+
[r.err_dev for r in results],
433+
[None, None, None, None, None, None, None, None, None, 0, 0, 0],
434+
)
406435
self.assertEqual(
407436
[-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
408437
df["onnx_id_node"].fillna(-10).tolist(),

onnx_diagnostic/_command_lines_parser.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,9 +1169,9 @@ def get_parser_sbs() -> ArgumentParser:
11691169
parser.add_argument(
11701170
"-r",
11711171
"--ratio",
1172-
default=5,
1172+
default=100,
11731173
required=False,
1174-
help="Saves the result in an excel file every <ratio> node.",
1174+
help="Saves the result in an excel file every <ratio> nodes.",
11751175
)
11761176
return parser
11771177

@@ -1244,10 +1244,14 @@ def _size(name):
12441244
pobs = post_process_run_aligned_obs(obs)
12451245
data.append(pobs)
12461246
if "initializer" not in pobs and "placeholder" not in pobs and len(data) % ratio == 0:
1247-
df = pandas.DataFrame(data)
1247+
df = pandas.DataFrame(data).apply(
1248+
lambda col: col.fillna("") if col.dtype == "object" else col
1249+
)
12481250
df.to_excel(args.output)
12491251
print(f"-- final saves into {args.output!r}")
1250-
df = pandas.DataFrame(data)
1252+
df = pandas.DataFrame(data).apply(
1253+
lambda col: col.fillna("") if col.dtype == "object" else col
1254+
)
12511255
df.to_excel(args.output)
12521256
print("-- done")
12531257

0 commit comments

Comments
 (0)