Skip to content

Commit 6de5c5c

Browse files
committed
replay onnx with torch
1 parent dda4227 commit 6de5c5c

File tree

4 files changed

+221
-35
lines changed

4 files changed

+221
-35
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def forward(self, x):
379379
use_tensor=True,
380380
),
381381
)
382-
df = pandas.DataFrame(list(results))
382+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
383383
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom.xlsx"))
384384
self.assertEqual(
385385
[
@@ -391,7 +391,6 @@ def forward(self, x):
391391
"err_abs",
392392
"err_dev",
393393
"err_h01",
394-
"err_nan",
395394
"err_rel",
396395
"onnx_id_node",
397396
"onnx_id_output",
@@ -445,7 +444,7 @@ def forward(self, x):
445444
use_tensor=True,
446445
),
447446
)
448-
df = pandas.DataFrame(list(results))
447+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
449448
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_dynamo.xlsx"))
450449
self.assertEqual(
451450
[
@@ -457,7 +456,6 @@ def forward(self, x):
457456
"err_abs",
458457
"err_dev",
459458
"err_h01",
460-
"err_nan",
461459
"err_rel",
462460
"onnx_id_node",
463461
"onnx_id_output",
@@ -542,7 +540,7 @@ def forward(self, x):
542540
reset_names=["linear"],
543541
),
544542
)
545-
df = pandas.DataFrame(list(results))
543+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
546544
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom_reset.xlsx"))
547545
onnx_op_type = df["onnx_op_type"].tolist()
548546
self.assertEqual(onnx_op_type.count("reset"), 1)
@@ -593,10 +591,80 @@ def forward(self, x):
593591
),
594592
),
595593
)
596-
df = pandas.DataFrame(list(results))
594+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
597595
df.to_excel(self.get_dump_file("test_sbs_replay.xlsx"))
598-
print(df)
599-
# self.clean_dump()
596+
self.assertEqual(df.shape, (8, 15))
597+
self.clean_dump()
598+
599+
@hide_stdout()
600+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
601+
def test_sbs_run_onnx_with_torch_inputs(self):
602+
torch = self.torch
603+
604+
class Model(self.torch.nn.Module):
605+
def __init__(self):
606+
super(Model, self).__init__()
607+
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
608+
self.relu = torch.nn.ReLU()
609+
self.fc2 = torch.nn.Linear(32, 1) # hidden → output
610+
611+
def forward(self, x):
612+
x = self.relu(self.fc1(x))
613+
x = self.fc2(x)
614+
return x
615+
616+
inputs = dict(x=self.torch.randn((5, 10)))
617+
ds = dict(x={0: "batch"})
618+
Model()(**inputs)
619+
ep = self.torch.export.export(
620+
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
621+
)
622+
filename = self.get_dump_file("test_sbs_run_onnx_with_torch_inputs.onnx")
623+
to_onnx(ep, exporter="custom", filename=filename)
624+
onx = onnx.load(filename)
625+
results = list(
626+
run_aligned(
627+
ep,
628+
onx,
629+
kwargs=inputs,
630+
run_cls=OnnxruntimeEvaluator,
631+
verbose=11,
632+
use_tensor=True,
633+
run_onnx_with_torch_inputs=True,
634+
),
635+
)
636+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
637+
df.to_excel(self.get_dump_file("test_sbs_run_onnx_with_torch_inputs.xlsx"))
638+
self.assertEqual(
639+
[
640+
"ep_id_node",
641+
"ep_name",
642+
"ep_shape_type",
643+
"ep_target",
644+
"ep_time_run",
645+
"err_abs",
646+
"err_abs2",
647+
"err_dev",
648+
"err_dev2",
649+
"err_h01",
650+
"err_h012",
651+
"err_rel",
652+
"err_rel2",
653+
"onnx_id_node",
654+
"onnx_id_output",
655+
"onnx_name",
656+
"onnx_op_type",
657+
"onnx_shape_type",
658+
"onnx_time_run",
659+
],
660+
sorted(df.columns),
661+
)
662+
self.assertEqual(len(results), 8)
663+
self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
664+
self.assertEqual(
665+
[-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
666+
)
667+
self.clean_dump()
600668

601669

602670
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,19 @@ def get_parser_sbs() -> ArgumentParser:
12171217
default=False,
12181218
help="First runs the whole model.",
12191219
)
1220+
parser.add_argument(
1221+
"-2",
1222+
"--second-run",
1223+
action=BooleanOptionalAction,
1224+
default=False,
1225+
help=textwrap.dedent(
1226+
"""
1227+
Tries to run all onnx nodes with torch results produced by the exported
1228+
program. It then measures the discrepancies again. It can be used
1229+
to distinguish kernels that introduce discrepancies from others that just propagate them.
1230+
"""
1231+
),
1232+
)
12201233
parser.add_argument(
12211234
"--reset",
12221235
required=False,
@@ -1365,6 +1378,7 @@ def _size(name):
13651378
reset_names=args.reset.split(","),
13661379
exc=False,
13671380
replay_configuration=replay_configuration,
1381+
run_onnx_with_torch_inputs=args.second_run,
13681382
):
13691383
data.append(obs)
13701384
if (
@@ -1377,8 +1391,10 @@ def _size(name):
13771391
)
13781392
df.to_excel(args.output)
13791393
print(f"-- final saves into {args.output!r}")
1380-
df = pandas.DataFrame(data).apply(
1381-
lambda col: col.fillna("") if col.dtype == "object" else col
1394+
df = (
1395+
pandas.DataFrame(data)
1396+
.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
1397+
.dropna(axis=1, how="all")
13821398
)
13831399
df.to_excel(args.output, index=False)
13841400
print("-- done")

0 commit comments

Comments
 (0)