Skip to content

Commit 9c9bf00

Browse files
authored
Adds discrepancies with the exporter program (#307)
* Adds discrepancies with the exporter program
* better
1 parent 0521c46 commit 9c9bf00

File tree

6 files changed

+114
-19
lines changed

6 files changed

+114
-19
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,15 @@ def _config_reduction(config, task):
115115
verbose=1,
116116
stop_if_static=2,
117117
):
118+
if exporter == "onnx-dynamo":
119+
# The exported program in ONNXProgram cannot be restored.
120+
ep2 = torch.export.export(
121+
model.visual,
122+
(),
123+
kwargs=export_inputs,
124+
dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
125+
)
126+
torch.export.save(ep2, f"{fileep}.backup.pt2")
118127
to_onnx(
119128
model.visual,
120129
kwargs=export_inputs,
@@ -127,6 +136,14 @@ def _config_reduction(config, task):
127136
optimize=True,
128137
)
129138

139+
pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
140+
pt2_file = [f for f in pt2_files if os.path.exists(f)]
141+
assert pt2_file, f"Unable to find an existing file among {pt2_files}"
142+
pt2_file = pt2_file[0]
143+
# self.assertExists(pt2_file)
144+
# ep = torch.export.load(pt2_file)
145+
# diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))
146+
# print("----------- diff", diff)
130147
self.assert_onnx_disc(
131148
f"test_imagetext2text_qwen_2_5_vl_instruct_visual.{device}.{dtype}.{exporter}",
132149
filename,
@@ -142,6 +159,7 @@ def _config_reduction(config, task):
142159
atol=0.02,
143160
rtol=10,
144161
ort_optimized_graph=False,
162+
ep=pt2_file,
145163
)
146164

147165

onnx_diagnostic/export/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def to_onnx(
112112
ort_fusions.optimize_for_ort(epo.model)
113113
if filename:
114114
epo.save(filename, external_data=True)
115+
if save_ep:
116+
if isinstance(save_ep, tuple):
117+
save_ep = save_ep[0]
118+
torch.export.save(epo.exported_program, f"{save_ep}.pt2")
115119
return epo
116120

117121
if exporter == "modelbuilder":

onnx_diagnostic/ext_test_case.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,7 @@ def assert_onnx_disc(
11991199
expected: Optional[Any] = None,
12001200
use_ort: bool = False,
12011201
ort_optimized_graph: bool = False,
1202+
ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
12021203
**kwargs,
12031204
):
12041205
"""
@@ -1218,6 +1219,7 @@ def assert_onnx_disc(
12181219
:param copy_inputs: to copy the inputs
12191220
:param use_ort: use :class:`onnxruntime.InferenceSession`
12201221
:param ort_optimized_graph: dumps the optimized onnxruntime graph
1222+
:param ep: exported program (or saved exported program)
12211223
:param kwargs: arguments sent to
12221224
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
12231225
"""
@@ -1245,6 +1247,7 @@ def assert_onnx_disc(
12451247
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
12461248
if verbose:
12471249
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
1250+
12481251
if use_ort:
12491252
assert isinstance(
12501253
proto, onnx.ModelProto
@@ -1275,6 +1278,7 @@ def assert_onnx_disc(
12751278
got = sess.run(None, feeds)
12761279
if verbose:
12771280
print(f"[{vname}] compute expected values")
1281+
12781282
if expected is None:
12791283
if copy_inputs:
12801284
expected = (
@@ -1284,9 +1288,45 @@ def assert_onnx_disc(
12841288
)
12851289
else:
12861290
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1291+
12871292
if verbose:
12881293
print(f"[{vname}] expected {string_type(expected, **kws)}")
12891294
print(f"[{vname}] obtained {string_type(got, **kws)}")
1295+
1296+
if ep:
1297+
if isinstance(ep, str):
1298+
if verbose:
1299+
print(f"[{vname}] load exported program {ep!r}")
1300+
import torch
1301+
1302+
ep = torch.export.load(ep)
1303+
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1304+
ep_model = ep.module() # type: ignore[union-attr]
1305+
ep_expected = (
1306+
ep_model(*copy.deepcopy(ep_inputs))
1307+
if isinstance(ep_inputs, tuple)
1308+
else ep_model(**copy.deepcopy(ep_inputs))
1309+
)
1310+
if verbose:
1311+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1312+
ep_diff = max_diff(expected, ep_expected)
1313+
if verbose:
1314+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1315+
assert (
1316+
isinstance(ep_diff["abs"], float)
1317+
and isinstance(ep_diff["rel"], float)
1318+
and not numpy.isnan(ep_diff["abs"])
1319+
and ep_diff["abs"] <= atol
1320+
and not numpy.isnan(ep_diff["rel"])
1321+
and ep_diff["rel"] <= rtol
1322+
), (
1323+
f"discrepancies in {test_name!r} between the model "
1324+
f"and the exported model diff={string_diff(ep_diff)}"
1325+
)
1326+
ep_nx_diff = max_diff(ep_expected, got, flatten=True)
1327+
if verbose:
1328+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1329+
12901330
diff = max_diff(expected, got, flatten=True)
12911331
if verbose:
12921332
print(f"[{vname}] diff {string_diff(diff)}")
@@ -1297,7 +1337,10 @@ def assert_onnx_disc(
12971337
and diff["abs"] <= atol
12981338
and not numpy.isnan(diff["rel"])
12991339
and diff["rel"] <= rtol
1300-
), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
1340+
), (
1341+
f"discrepancies in {test_name!r} between the model and "
1342+
f"the onnx model diff={string_diff(diff)}"
1343+
)
13011344

13021345
def _debug(self):
13031346
"Tells if DEBUG=1 is set up."
@@ -1308,6 +1351,16 @@ def string_type(self, *args, **kwargs):
13081351

13091352
return string_type(*args, **kwargs)
13101353

1354+
def max_diff(self, *args, **kwargs):
1355+
from .helpers import max_diff
1356+
1357+
return max_diff(*args, **kwargs)
1358+
1359+
def use_dyn_not_str(self, *args, **kwargs):
1360+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1361+
1362+
return use_dyn_not_str(*args, *kwargs)
1363+
13111364
def subloop(self, *args, verbose: int = 0):
13121365
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
13131366
if len(args) == 1:

onnx_diagnostic/helpers/ort_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,13 @@ def __init__(
134134

135135
self.sess = sess
136136
self.input_names = [i.name for i in sess.get_inputs()]
137+
assert (
138+
"" not in self.input_names
139+
), f"Input name cannot be empty but input_names={self.input_names}"
137140
self.output_names = [i.name for i in sess.get_outputs()]
141+
assert (
142+
"" not in self.input_names
143+
), f"Output name cannot be empty but output_names={self.output_names}"
138144
self.input_shapes = [i.shape for i in sess.get_inputs()]
139145
self.output_shapes = [i.shape for i in sess.get_outputs()]
140146
self.input_types = [i.type for i in sess.get_inputs()]
@@ -497,6 +503,7 @@ def run_dlpack(
497503
values = ORTC.OrtValueVector()
498504
device = -1
499505
for k, v in feeds.items():
506+
assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
500507
device = max(device, v.get_device())
501508
assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
502509
if not v.is_contiguous():

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -564,18 +564,14 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
564564
onx, sess = self._get_sess(node, inputs)
565565
self._cache[key] = onx, sess
566566

567-
feeds = dict(zip(node.input, inputs))
568-
if "" in feeds:
569-
cls = None
570-
for k, v in feeds.items():
571-
if k != "":
572-
cls = v.__class__
573-
break
574-
assert (
575-
cls is not None
576-
), f"Unable to get input class (array or tensor), feeds={string_type(feeds)}"
577-
feeds[""] = cls([0])
578-
567+
feeds = {}
568+
for i, val in zip(node.input, inputs):
569+
if i == "":
570+
assert (
571+
val is None
572+
), f"input name={i!r} but val={string_type(val, with_shape=True)}"
573+
continue
574+
feeds[i] = val
579575
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
580576
outputs = list(sess.run(None, feeds))
581577
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,11 @@ def _loop_cmp(
567567
print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")
568568

569569
placeholders = {node.name for node in ep.graph.nodes if node.op == "placeholder"}
570-
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers())}
570+
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
571571
placeholders_to_state_dict = {
572572
**{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
573573
**{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
574+
**{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
574575
}
575576
for n in onnx_results:
576577
if n not in placeholders:
@@ -588,6 +589,7 @@ def _loop_cmp(
588589
else:
589590
loop = list(enumerate(ep_graph_nodes))
590591

592+
already_run = set()
591593
ep_durations = {}
592594
yielded_nodes = 0
593595
max_abs = 0
@@ -641,8 +643,8 @@ def _loop_cmp(
641643
yield record
642644
else:
643645
assert node.name in placeholders_to_state_dict, (
644-
f"Unable to find placeholder {node.name!r} in "
645-
f"{sorted(placeholders_to_state_dict)}"
646+
f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
647+
f"existing: {sorted(placeholders_to_state_dict)}"
646648
)
647649
torch_results[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
648650
if verbose > 1:
@@ -683,6 +685,8 @@ def _loop_cmp(
683685
continue
684686

685687
for i_onnx in range(last_position, max_pos + 1):
688+
if i_onnx in already_run:
689+
continue
686690
node = onx.graph.node[i_onnx]
687691
if verbose > 1:
688692
print(
@@ -695,9 +699,16 @@ def _loop_cmp(
695699
f"mapped {yielded_nodes} maxabs {max_abs:1.5f}"
696700
)
697701
ref = run_cls(node, **run_cls_kwargs)
698-
feeds = {k: onnx_results[k] for k in node.input}
702+
feeds = {k: onnx_results[k] for k in node.input if k}
703+
assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
699704
begin = time.perf_counter()
700-
res = ref.run(None, feeds) # type: ignore[attr-defined]
705+
try:
706+
res = ref.run(None, feeds) # type: ignore[attr-defined]
707+
except Exception as e:
708+
raise RuntimeError(
709+
f"Unable to run node {node.op_type}, domain={node.domain} "
710+
f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
711+
) from e
701712
duration = time.perf_counter() - begin
702713
assert (
703714
not has_cuda
@@ -748,6 +759,7 @@ def _loop_cmp(
748759
if tmp.err_abs is not None:
749760
max_abs = max(max_abs, tmp.err_abs)
750761
yield tmp
762+
already_run.add(i_onnx)
751763

752764
last_position = max_pos + 1
753765

@@ -758,14 +770,17 @@ def _loop_cmp(
758770
f"to {len(onx.graph.node)}"
759771
)
760772
for i_onnx in range(last_position, len(onx.graph.node)):
773+
if i_onnx in already_run:
774+
continue
761775
node = onx.graph.node[i_onnx]
762776
if verbose > 1:
763777
print(
764778
f"[run_aligned] run onx.graph.node[{i_onnx}]: "
765779
f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
766780
)
767781
ref = run_cls(node, **run_cls_kwargs)
768-
feeds = {k: onnx_results[k] for k in node.input}
782+
feeds = {k: onnx_results[k] for k in node.input if k}
783+
assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
769784
begin = time.perf_counter()
770785
res = ref.run(None, feeds) # type: ignore[attr-defined]
771786
duration = time.perf_counter() - begin
@@ -800,6 +815,8 @@ def _loop_cmp(
800815
if tmp.err_abs is not None:
801816
max_abs = max(max_abs, tmp.err_abs)
802817
yield tmp
818+
already_run.add(i_onnx)
819+
803820
if verbose:
804821
print(f"[run_aligned] done with {yielded_nodes} mapped nodes")
805822
print(f"[run_aligned] max absolution error={max_abs}")

0 commit comments

Comments (0)