Skip to content

Commit 2a3a2d4

Browse files
committed
few changes
1 parent 513fc03 commit 2a3a2d4

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
import unittest
44
import onnx
5+
import textwrap
56
import torch
67
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings
78
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -47,6 +48,16 @@ def test_qwen25_vli_visual(self):
4748
TESTDTYPE=float16 \\
4849
EXPORTER=custom \\
4950
python _unittests/ut_tasks/try_export.py -k qwen25_vli_visual
51+
52+
.. code-block:: bash
53+
54+
python -m onnx_diagnostic sbs \\
55+
-i qwen25_vli_visual.inputs.pt \\
56+
-e test_qwen25_vli_visual.cuda.float16.PACKED.custom.graph.ep \\
57+
-m test_qwen25_vli_visual.cuda.float16.PACKED.custom.onnx \\
58+
-o test_qwen25_vli_visual.cuda.float16.PACKED.custom.xlsx \\
59+
-v 1 --atol 0.1 --rtol 1000
60+
5061
"""
5162
begin = time.perf_counter()
5263
device = os.environ.get("TESTDEVICE", "cpu")
@@ -174,6 +185,24 @@ def _config_reduction(config, task):
174185
onnx_plugs=PLUGS,
175186
)
176187

188+
with open(
189+
self.get_dump_file(
190+
f"sbs_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.sh"
191+
),
192+
"w",
193+
) as f:
194+
f.write(
195+
textwrap.dedent(
196+
f"""
197+
clear&&python -m onnx_diagnostic sbs \\
198+
-i qwen25_vli_visual.inputs.pt \\
199+
-e test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.graph.ep.pt2 \\
200+
-m test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.onnx \\
201+
-o test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.xlsx \\
202+
-v 1 --atol 0.1 --rtol 1000
203+
"""
204+
)
205+
)
177206
print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
178207
model = onnx.load(filename, load_external_data=False)
179208
if attention == "PACKED":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ def get_parser_sbs() -> ArgumentParser:
12091209
"--ratio",
12101210
default=100,
12111211
required=False,
1212-
help="Saves the result in an excel file every <ratio> nodes.",
1212+
help="Saves the result in an excel file every <ratio> nodes, default is 100.",
12131213
)
12141214
parser.add_argument(
12151215
"--first",
@@ -1247,7 +1247,7 @@ def get_parser_sbs() -> ArgumentParser:
12471247
"--replay-threshold",
12481248
type=float,
12491249
required=False,
1250-
default=1e9,
1250+
default=1e18,
12511251
help="Triggers the replay if the discrepancies are higher than this value.",
12521252
)
12531253
parser.add_argument(

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,13 @@ def _preparation_with_onnx_model(
513513
new_name = new_names[0] # type: ignore[assignment, index]
514514
t = torch_results[new_name]
515515
if (
516-
t.shape == tuple(init.dims)
516+
len(set(t.shape)) == len(t.shape) # not repeated dimension
517+
and t.shape == tuple(init.dims)
517518
and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
518519
):
519520
torch_names_to_onnx_names[new_name] = init.name
521+
else:
522+
t = None
520523

521524
# We should check tensors and proto are the same.
522525
if t is None:

0 commit comments

Comments
 (0)