Skip to content

Commit 300184e

Browse files
authored
Replay onnx nodes with torch results in the side by side (#322)
* replay onnx with torch * mypy * fix * doc * doc
1 parent dda4227 commit 300184e

File tree

6 files changed

+316
-38
lines changed

6 files changed

+316
-38
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.3
55
+++++
66

7+
* :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
78
* :pr:`314`: fix modelbuilder download needed after this change https://github.com/microsoft/onnxruntime-genai/pull/1862
89
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
910
* :pr:`310`: splits patches into multiple files

_doc/cmds/sbs.rst

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,40 @@ CPU, CUDA
2020

2121
Inputs are saved :func:`torch.save`. The execution will run on CUDA
2222
if the device of the inputs is CUDA, same goes on CPU.
23+
24+
Example
25+
+++++++
26+
27+
.. code-block::
28+
29+
python -m onnx_diagnostic sbs \
30+
-i qwen_2_5_vl_instruct_visual.inputs.pt \
31+
--ep test_imagetext2text_qwen_2_5_vl_instruct_visual.cuda.float16.custom.graph.ep.pt2 \
32+
-m test_imagetext2text_qwen_2_5_vl_instruct_visual.cuda.float16.custom.onnx \
33+
-o results.dynamo.float16.xlsx \
34+
-v 1 --atol=0.1 --rtol=1 \
35+
--replay-names conv3d,rsqrt,to_4,mul_48,linear,linear_2,linear_84,linear_89,mul_172,linear_156,linear_159 \
36+
-2 --reset conv3d
37+
38+
A snippet of the table it produces:
39+
40+
::
41+
42+
ep_name onnx_name ep_target onnx_op_type onnx_id_output ep_shape_type onnx_shape_type err_abs
43+
transpose_18 transpose_18 aten.transpose.int Transpose 0 GT10s16x1292x80 GT10s16x1292x80 0.0083
44+
unsqueeze_50 unsqueeze_50 aten.unsqueeze.default Unsqueeze 0 GT10s1x16x1292x80 GT10s1x16x1292x80 0.0083
45+
eq_20 eq_20 aten.eq.Scalar Equal 0 GT9s1292x1292 GT9s1292x1292 0
46+
unsqueeze_56 unsqueeze_56 aten.unsqueeze.default Unsqueeze 0 GT9s1x1x1292x1292 GT9s1x1x1292x1292 0
47+
slice_29 slice_29 aten.slice.Tensor Slice 0 GT9s1x1x1292x1292 GT9s1x1x1292x1292 0
48+
transpose_19 transpose_19 aten.transpose.int Transpose 0 GT10s1x1292x16x80 GT10s1x1292x16x80 0.0071
49+
reshape_20 reshape_20 aten.reshape.default Reshape 0 GT10s1292x1280 GT10s1292x1280 0.0071
50+
linear_21 linear_21 aten.linear.default Gemm 0 GT10s1292x1280 GT10s1292x1280 0.0015
51+
mul_54 mul_54 aten.mul.Tensor SkipSimplifiedLayerNormalization 0 GT10s1292x1280 GT10s1292x1280 0.0098
52+
add_32 add_32 aten.add.Tensor SkipSimplifiedLayerNormalization 3 GT10s1292x1280 GT10s1292x1280 0.0313
53+
linear_22 linear_22 aten.linear.default Gemm 0 GT10s1292x3420 GT10s1292x3420 0.0078
54+
silu_4 silu_4 aten.silu.default QuickGelu 0 GT10s1292x3420 GT10s1292x3420 0.0059
55+
56+
The available column are described by
57+
:class:`RunAlignedRecord <onnx_diagnostic.torch_onnx.sbs_dataclasses.RunAlignedRecord>`.
58+
It is possible to dump pieces of the model to study some particular input
59+
with :class:`ReplayConfiguration <onnx_diagnostic.torch_onnx.sbs_dataclasses.ReplayConfiguration>`.

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def forward(self, x):
379379
use_tensor=True,
380380
),
381381
)
382-
df = pandas.DataFrame(list(results))
382+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
383383
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom.xlsx"))
384384
self.assertEqual(
385385
[
@@ -390,8 +390,8 @@ def forward(self, x):
390390
"ep_time_run",
391391
"err_abs",
392392
"err_dev",
393+
"err_h001",
393394
"err_h01",
394-
"err_nan",
395395
"err_rel",
396396
"onnx_id_node",
397397
"onnx_id_output",
@@ -445,7 +445,7 @@ def forward(self, x):
445445
use_tensor=True,
446446
),
447447
)
448-
df = pandas.DataFrame(list(results))
448+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
449449
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_dynamo.xlsx"))
450450
self.assertEqual(
451451
[
@@ -456,8 +456,8 @@ def forward(self, x):
456456
"ep_time_run",
457457
"err_abs",
458458
"err_dev",
459+
"err_h001",
459460
"err_h01",
460-
"err_nan",
461461
"err_rel",
462462
"onnx_id_node",
463463
"onnx_id_output",
@@ -542,7 +542,7 @@ def forward(self, x):
542542
reset_names=["linear"],
543543
),
544544
)
545-
df = pandas.DataFrame(list(results))
545+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
546546
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom_reset.xlsx"))
547547
onnx_op_type = df["onnx_op_type"].tolist()
548548
self.assertEqual(onnx_op_type.count("reset"), 1)
@@ -593,10 +593,83 @@ def forward(self, x):
593593
),
594594
),
595595
)
596-
df = pandas.DataFrame(list(results))
596+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
597597
df.to_excel(self.get_dump_file("test_sbs_replay.xlsx"))
598-
print(df)
599-
# self.clean_dump()
598+
self.assertEqual(df.shape, (8, 16))
599+
self.clean_dump()
600+
601+
@hide_stdout()
602+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
603+
def test_sbs_run_onnx_with_torch_inputs(self):
604+
torch = self.torch
605+
606+
class Model(self.torch.nn.Module):
607+
def __init__(self):
608+
super(Model, self).__init__()
609+
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
610+
self.relu = torch.nn.ReLU()
611+
self.fc2 = torch.nn.Linear(32, 1) # hidden → output
612+
613+
def forward(self, x):
614+
x = self.relu(self.fc1(x))
615+
x = self.fc2(x)
616+
return x
617+
618+
inputs = dict(x=self.torch.randn((5, 10)))
619+
ds = dict(x={0: "batch"})
620+
Model()(**inputs)
621+
ep = self.torch.export.export(
622+
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
623+
)
624+
filename = self.get_dump_file("test_sbs_run_onnx_with_torch_inputs.onnx")
625+
to_onnx(ep, exporter="custom", filename=filename)
626+
onx = onnx.load(filename)
627+
results = list(
628+
run_aligned(
629+
ep,
630+
onx,
631+
kwargs=inputs,
632+
run_cls=OnnxruntimeEvaluator,
633+
verbose=11,
634+
use_tensor=True,
635+
run_onnx_with_torch_inputs=True,
636+
),
637+
)
638+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
639+
df.to_excel(self.get_dump_file("test_sbs_run_onnx_with_torch_inputs.xlsx"))
640+
self.assertEqual(
641+
[
642+
"comment",
643+
"ep_id_node",
644+
"ep_name",
645+
"ep_shape_type",
646+
"ep_target",
647+
"ep_time_run",
648+
"err_abs",
649+
"err_abs2",
650+
"err_dev",
651+
"err_dev2",
652+
"err_h001",
653+
"err_h0012",
654+
"err_h01",
655+
"err_h012",
656+
"err_rel",
657+
"err_rel2",
658+
"onnx_id_node",
659+
"onnx_id_output",
660+
"onnx_name",
661+
"onnx_op_type",
662+
"onnx_shape_type",
663+
"onnx_time_run",
664+
],
665+
sorted(df.columns),
666+
)
667+
self.assertEqual(len(results), 8)
668+
self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
669+
self.assertEqual(
670+
[-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
671+
)
672+
self.clean_dump()
600673

601674

602675
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,19 @@ def get_parser_sbs() -> ArgumentParser:
12171217
default=False,
12181218
help="First runs the whole model.",
12191219
)
1220+
parser.add_argument(
1221+
"-2",
1222+
"--second-run",
1223+
action=BooleanOptionalAction,
1224+
default=False,
1225+
help=textwrap.dedent(
1226+
"""
1227+
Tries to run all onnx nodes with torch results produced by the exported
1228+
program. It then measures the discrepancies again. It can be used
1229+
to identify kernel introduces discrepancies from other just propagating them.
1230+
"""
1231+
),
1232+
)
12201233
parser.add_argument(
12211234
"--reset",
12221235
required=False,
@@ -1365,6 +1378,7 @@ def _size(name):
13651378
reset_names=args.reset.split(","),
13661379
exc=False,
13671380
replay_configuration=replay_configuration,
1381+
run_onnx_with_torch_inputs=args.second_run,
13681382
):
13691383
data.append(obs)
13701384
if (
@@ -1377,8 +1391,10 @@ def _size(name):
13771391
)
13781392
df.to_excel(args.output)
13791393
print(f"-- final saves into {args.output!r}")
1380-
df = pandas.DataFrame(data).apply(
1381-
lambda col: col.fillna("") if col.dtype == "object" else col
1394+
df = (
1395+
pandas.DataFrame(data)
1396+
.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
1397+
.dropna(axis=1, how="all")
13821398
)
13831399
df.to_excel(args.output, index=False)
13841400
print("-- done")

0 commit comments

Comments
 (0)