Skip to content

Commit e7399a1

Browse files
committed
fix
1 parent f8bca95 commit e7399a1

File tree

5 files changed

+67
-27
lines changed

5 files changed

+67
-27
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import onnx
23
from onnx_diagnostic.ext_test_case import (
34
ExtTestCase,
45
hide_stdout,
@@ -9,11 +10,7 @@
910
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
1011
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1112
from onnx_diagnostic.torch_onnx.sbs import run_aligned
12-
13-
try:
14-
from experimental_experiment.torch_interpreter import to_onnx
15-
except ImportError:
16-
to_onnx = None
13+
from onnx_diagnostic.export.api import to_onnx
1714

1815

1916
class TestSideBySide(ExtTestCase):
@@ -41,7 +38,7 @@ def forward(self, x):
4138
ep = self.torch.export.export(
4239
Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},)
4340
)
44-
onx = to_onnx(ep)
41+
onx = to_onnx(ep, exporter="custom").model_proto
4542
results = list(
4643
run_aligned(
4744
ep,
@@ -71,10 +68,12 @@ def forward(self, x):
7168
ep = self.torch.export.export(
7269
Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},)
7370
)
74-
epo = self.torch.onnx.export(
75-
ep, (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},), dynamo=True
76-
)
77-
onx = epo.model_proto
71+
onx = to_onnx(
72+
ep,
73+
(x,),
74+
dynamic_shapes=({0: self.torch.export.Dim("batch")},),
75+
exporter="onnx-dynamo",
76+
).model_proto
7877
results = list(
7978
run_aligned(
8079
ep,
@@ -105,9 +104,7 @@ def forward(self, x):
105104
ep = self.torch.export.export(
106105
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
107106
)
108-
epo = self.torch.onnx.export(
109-
Model(), (), kwargs=inputs, dynamic_shapes=ds, dynamo=True
110-
)
107+
epo = to_onnx(Model(), (), kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
111108
onx = epo.model_proto
112109
results = list(
113110
run_aligned(
@@ -139,7 +136,7 @@ def forward(self, x):
139136
ep = self.torch.export.export(
140137
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
141138
)
142-
onx = to_onnx(ep)
139+
onx = to_onnx(ep, exporter="custom").model_proto
143140
results = list(
144141
run_aligned(
145142
ep,
@@ -170,7 +167,7 @@ def forward(self, x):
170167
ep = self.torch.export.export(
171168
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
172169
)
173-
onx = to_onnx(ep)
170+
onx = to_onnx(ep, exporter="custom").model_proto
174171
results = list(
175172
run_aligned(
176173
ep,
@@ -204,7 +201,7 @@ def forward(self, x):
204201
ep = self.torch.export.export(
205202
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
206203
)
207-
onx = to_onnx(ep)
204+
onx = to_onnx(ep, exporter="custom").model_proto
208205
results = list(
209206
run_aligned(
210207
ep,
@@ -240,7 +237,7 @@ def forward(self, x):
240237
ep = self.torch.export.export(
241238
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
242239
)
243-
onx = to_onnx(ep)
240+
onx = to_onnx(ep, exporter="custom").model_proto
244241
results = list(
245242
run_aligned(
246243
ep,
@@ -275,7 +272,7 @@ def forward(self, x):
275272
ep = self.torch.export.export(
276273
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
277274
)
278-
onx = to_onnx(ep)
275+
onx = to_onnx(ep, exporter="custom").model_proto
279276
results = list(
280277
run_aligned(
281278
ep,
@@ -291,6 +288,45 @@ def forward(self, x):
291288
self.assertEqual(len(results), 7)
292289
self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0])
293290

291+
@hide_stdout()
292+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
293+
def test_sbs_model_with_weights(self):
294+
torch = self.torch
295+
296+
class Model(self.torch.nn.Module):
297+
def __init__(self):
298+
super(Model, self).__init__()
299+
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
300+
self.relu = torch.nn.ReLU()
301+
self.fc2 = torch.nn.Linear(32, 1) # hidden → output
302+
303+
def forward(self, x):
304+
x = self.relu(self.fc1(x))
305+
x = self.fc2(x)
306+
return x
307+
308+
inputs = dict(x=self.torch.randn((5, 10)))
309+
ds = dict(x={0: "batch"})
310+
Model()(**inputs)
311+
ep = self.torch.export.export(
312+
Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
313+
)
314+
filename = self.get_dump_file("test_sbs_model_with_weights.onnx")
315+
to_onnx(ep, exporter="custom", filename=filename)
316+
onx = onnx.load(filename)
317+
results = list(
318+
run_aligned(
319+
ep,
320+
onx,
321+
kwargs=inputs,
322+
run_cls=OnnxruntimeEvaluator,
323+
verbose=11,
324+
use_tensor=True,
325+
),
326+
)
327+
self.assertEqual(len(results), 7)
328+
self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0])
329+
294330

295331
if __name__ == "__main__":
296332
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ def post_process(obs):
12301230
for obs in run_aligned(
12311231
ep,
12321232
onx,
1233-
run_cls=OnnxruntimeEvaluator,
1233+
run_cls=OnnxruntimeEvaluator, # type: ignore[arg-type]
12341234
atol=float(args.atol),
12351235
rtol=float(args.rtol),
12361236
verbose=int(args.verbose),

onnx_diagnostic/helpers/helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ def max_diff(
12071207
if "dev" in d:
12081208
if dd is None:
12091209
dd = d["dev"]
1210-
else:
1210+
elif d["dev"] is not None:
12111211
dd += d["dev"]
12121212

12131213
res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
@@ -1264,7 +1264,7 @@ def max_diff(
12641264
# out of boundary
12651265
res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
12661266
if dev:
1267-
res[dev] = dev
1267+
res["dev"] = dev
12681268
return res
12691269
if isinstance(expected, (int, float)):
12701270
if isinstance(got, np.ndarray) and len(got.shape) == 0:
@@ -1280,7 +1280,7 @@ def max_diff(
12801280
dnan=0,
12811281
)
12821282
if dev:
1283-
res[dev] = dev
1283+
res["dev"] = dev
12841284
return res
12851285
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
12861286
if expected.dtype in (np.complex64, np.complex128):
@@ -1362,7 +1362,7 @@ def max_diff(
13621362
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
13631363
)
13641364
if dev:
1365-
res[dev] = dev
1365+
res["dev"] = dev
13661366
if hist:
13671367
if isinstance(hist, bool):
13681368
hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,13 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
553553
feeds = dict(zip(node.input, inputs))
554554
if "" in feeds:
555555
cls = None
556-
for k, v in feeds:
556+
for k, v in feeds.items():
557557
if k != "":
558558
cls = v.__class__
559+
break
560+
assert (
561+
cls is not None
562+
), f"Unable to get input class (array or tensor), feeds={string_type(feeds)}"
559563
feeds[""] = cls([0])
560564

561565
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,9 @@ def _loop_cmp(
537537
print(f"[run_aligned] onnx {len(onnx_results)} constants")
538538
print(f"[run_aligned] common {len(mapping_onnx_to_torch)} constants")
539539
for k, v in torch_results.items():
540-
print(f"[run_aligned-ep] +cst: {k}: {string_type(v, str_kws)}")
540+
print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}")
541541
for k, v in onnx_results.items():
542-
print(f"[run_aligned-nx] +ini: {k}: {string_type(v, str_kws)}")
542+
print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")
543543

544544
onnx_args = list(args) if args else []
545545
if kwargs:
@@ -578,7 +578,7 @@ def _loop_cmp(
578578
if exc:
579579
raise AssertionError(
580580
f"unable to process node {node.op} -> {node.name!r}, "
581-
f"possible candiate are "
581+
f"possible candidate are "
582582
f"{sorted(p for p in onnx_results if node.name in p)}, "
583583
f"not in {sorted(onnx_results)}, "
584584
f"args={string_type(args, **str_kws)}, "

0 commit comments

Comments
 (0)