
Commit 1346aa9: "fix ut"
Parent: b4916a8

File tree: 3 files changed (+74 -34 lines)
  _unittests/ut_torch_onnx/test_sbs.py
  _unittests/ut_xrun_doc/test_command_lines_exe.py
  onnx_diagnostic/torch_onnx/sbs.py


_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 22 additions & 26 deletions
@@ -29,15 +29,15 @@ def test_run_aligned_record(self):
             onnx_name="B",
             ep_target="C",
             onnx_op_type="D",
-            shape_type="E",
+            ep_shape_type="E",
             err_abs=0.1,
             err_rel=0.2,
             err_dev=0.3,
             err_nan=0.4,
         )
         sr = str(r)
         self.assertIn("RunAlignedRecord(", sr)
-        self.assertIn("shape_type='E'", sr)
+        self.assertIn("ep_shape_type='E'", sr)

     @hide_stdout()
     @unittest.skipIf(to_onnx is None, "to_onnx not installed")
@@ -69,7 +69,7 @@ def forward(self, x):
                 verbose=10,
             ),
         )
-        self.assertEqual(len(results), 6)
+        self.assertEqual(len(results), 7)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -104,7 +104,7 @@ def forward(self, x):
                 verbose=10,
             ),
         )
-        self.assertEqual(len(results), 5)
+        self.assertEqual(len(results), 6)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -136,7 +136,7 @@ def forward(self, x):
                 verbose=10,
             ),
         )
-        self.assertEqual(len(results), 5)
+        self.assertEqual(len(results), 6)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -167,7 +167,7 @@ def forward(self, x):
                 verbose=11,
             ),
         )
-        self.assertEqual(len(results), 6)
+        self.assertEqual(len(results), 7)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -199,7 +199,7 @@ def forward(self, x):
                 use_tensor=True,
             ),
         )
-        self.assertEqual(len(results), 7)
+        self.assertEqual(len(results), 8)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -232,7 +232,7 @@ def forward(self, x):
                 use_tensor=True,
             ),
         )
-        self.assertEqual(len(results), 7)
+        self.assertEqual(len(results), 8)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -267,7 +267,7 @@ def forward(self, x):
                 use_tensor=True,
             ),
         )
-        self.assertEqual(len(results), 8)
+        self.assertEqual(len(results), 14)

     @hide_stdout()
     @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -301,9 +301,9 @@ def forward(self, x):
                 use_tensor=True,
             ),
         )
-        self.assertEqual(len(results), 8)
+        self.assertEqual(len(results), 14)
         self.assertEqual(
-            [None, None, 0, 0, 0, 0, 0, 0],
+            [None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0, 0],
             [r.err_dev for r in results],
         )

@@ -349,29 +349,27 @@ def forward(self, x):
             [
                 "ep_id_node",
                 "ep_name",
+                "ep_shape_type",
                 "ep_target",
                 "ep_time_run",
                 "err_abs",
                 "err_dev",
+                "err_h01",
                 "err_nan",
                 "err_rel",
                 "onnx_id_node",
                 "onnx_id_output",
                 "onnx_name",
                 "onnx_op_type",
+                "onnx_shape_type",
                 "onnx_time_run",
-                "shape_type",
             ],
             sorted(df.columns),
         )
         self.assertEqual(len(results), 8)
+        self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
         self.assertEqual(
-            [None, None, None, None, None, 0, 0, 0],
-            [r.err_dev for r in results],
-        )
-        self.assertEqual(
-            [-10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
-            df["onnx_id_node"].fillna(-10).tolist(),
+            [-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
         )
         self.clean_dump()

@@ -417,29 +415,27 @@ def forward(self, x):
             [
                 "ep_id_node",
                 "ep_name",
+                "ep_shape_type",
                 "ep_target",
                 "ep_time_run",
                 "err_abs",
                 "err_dev",
+                "err_h01",
                 "err_nan",
                 "err_rel",
                 "onnx_id_node",
                 "onnx_id_output",
                 "onnx_name",
                 "onnx_op_type",
+                "onnx_shape_type",
                 "onnx_time_run",
-                "shape_type",
             ],
             sorted(df.columns),
         )
         self.assertEqual(len(results), 8)
+        self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
         self.assertEqual(
-            [None, None, None, None, None, 0, 0, 0],
-            [r.err_dev for r in results],
-        )
-        self.assertEqual(
-            [-10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
-            df["onnx_id_node"].fillna(-10).tolist(),
+            [-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
        )
         self.clean_dump()

@@ -466,7 +462,7 @@ def forward(self, x):
                 use_tensor=True,
             ),
         )
-        self.assertEqual(len(results), 2)
+        self.assertEqual(len(results), 5)


 if __name__ == "__main__":
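The test changes above come down to a new column layout for the aligned-run results: the single shape_type field is split into ep_shape_type and onnx_shape_type, and an err_h01 metric is added, so more rows (initializers included) show up in the results. Below is a minimal sketch, assuming pandas and using made-up record values (the real records are RunAlignedRecord instances), of the 16-column frame the updated assertions expect:

# Minimal sketch, not from the commit: column set checked by the updated test.
# The values are placeholders; only the key names come from the diff above.
import pandas as pd

record = {
    "ep_id_node": 0,
    "ep_name": "A",
    "ep_shape_type": "E",
    "ep_target": "C",
    "ep_time_run": 0.0,
    "err_abs": 0.1,
    "err_dev": 0.3,
    "err_h01": 0.0,
    "err_nan": 0.4,
    "err_rel": 0.2,
    "onnx_id_node": -1,
    "onnx_id_output": 0,
    "onnx_name": "B",
    "onnx_op_type": "D",
    "onnx_shape_type": "E",
    "onnx_time_run": 0.0,
}
df = pd.DataFrame([record])
assert sorted(df.columns) == sorted(record)  # the same 16 columns as in the test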

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 4 additions & 3 deletions
@@ -128,7 +128,7 @@ def forward(self, x):
            [
                "sbs",
                "-v",
-               "1",
+               "2",
                "--first",
                "-i",
                input_file,
@@ -151,9 +151,10 @@ def forward(self, x):
         self.assertIn("p_fc1_weight", set(df["ep_name"]))
         self.assertIn("fc1.bias", set(df["onnx_name"]))
         self.assertNotIn("NaN", set(df["ep_name"]))
-        print(df)
-        print(st.getvalue())
+        # print(f"{df}\n{st.getvalue()}")
         self.assertIn("[run_aligned] done", st.getvalue())
+        sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
+        self.assertEqual(sdf.shape[0], 4)


 if __name__ == "__main__":
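The new assertion counts how many rows pair a torch placeholder with an onnx initializer. A minimal sketch of that pandas filter, using a hypothetical stand-in frame (the real one is produced by the sbs command):

# Minimal sketch, not from the commit: the filter used by the new assertion.
# The rows below are invented; "aten.addmm.default"/"Gemm" are just examples.
import pandas as pd

df = pd.DataFrame(
    {
        "ep_target": ["placeholder", "placeholder", "aten.addmm.default"],
        "onnx_op_type": ["initializer", "Gemm", "Gemm"],
    }
)
sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
print(sdf.shape[0])  # number of torch placeholders matched to onnx initializers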

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 48 additions & 5 deletions
@@ -8,7 +8,7 @@
 import torch
 from ..helpers import string_type, string_diff, max_diff, flatten_object
 from ..helpers.onnx_helper import pretty_onnx
-from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor
+from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype


 def validate_fx_tensor(
@@ -753,6 +753,39 @@ def _gemm_linear(node, feeds, sess):
         if verbose:
             print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")

+    # alias for initializers
+    skip_onnx_name = set()
+    init_aliases = {}
+    for init in onx.graph.initializer:
+        new_names = {
+            n
+            for n in [
+                f"p_{init.name.replace('.', '_')}",
+                f"p_{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+                f"{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+            ]
+            if n != init.name
+        }
+        drop = False
+        for new_name in new_names:
+            if new_name in skip_onnx_name:
+                drop = True
+                break
+        if drop:
+            skip_onnx_name |= new_names | {init.name}
+            for new_name in new_names:
+                if new_names in init_aliases:
+                    del init_aliases[new_name]
+        else:
+            for new_name in new_names:
+                init_aliases[new_name] = init.name
+    rev_init_aliases = {}
+    for k, v in init_aliases.items():
+        if v in rev_init_aliases:
+            rev_init_aliases[v].add(k)
+        else:
+            rev_init_aliases[v] = {k}
+
     # initializers
     if verbose:
         print(f"[run_aligned] nx: handles {len(onx.graph.initializer)} initializers from onnx")
@@ -765,11 +798,21 @@ def _gemm_linear(node, feeds, sess):
         if init.name not in skip_mapping_torch_onnx:
             t = torch_results[init.name]
             torch_names_to_onnx_names[init.name] = init.name
-        else:
-            new_name = f"p_{init.name.replace('.', '_')}"
-            if new_name not in skip_mapping_torch_onnx and new_name in torch_results:
+        elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
+            new_names = [
+                k
+                for k in rev_init_aliases[init.name]
+                if k in torch_results and k not in skip_mapping_torch_onnx
+            ]
+            if new_names and len(new_names) == 1:
+                new_name = new_names[0]
                 t = torch_results[new_name]
-                torch_names_to_onnx_names[new_name] = init.name
+                if (
+                    t.shape == tuple(init.dims)
+                    and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
+                ):
+                    torch_names_to_onnx_names[new_name] = init.name
+
         # We should check tensors and proto are the same.
         if t is None:
             t = to_tensor(init)