Skip to content

Commit e8d2d12

Browse files
committed
better
1 parent cd3ff9f commit e8d2d12

File tree

6 files changed

+61
-20
lines changed

6 files changed

+61
-20
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,15 @@ def _config_reduction(config, task):
115115
verbose=1,
116116
stop_if_static=2,
117117
):
118+
if exporter == "onnx-dynamo":
119+
# The exported program in ONNXProgram cannot be restored.
120+
ep2 = torch.export.export(
121+
model.visual,
122+
(),
123+
kwargs=export_inputs,
124+
dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
125+
)
126+
torch.export.save(ep2, f"{fileep}.backup.pt2")
118127
to_onnx(
119128
model.visual,
120129
kwargs=export_inputs,
@@ -127,7 +136,10 @@ def _config_reduction(config, task):
127136
optimize=True,
128137
)
129138

130-
pt2_file = f"{fileep}.ep.pt2"
139+
pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
140+
pt2_file = [f for f in pt2_files if os.path.exists(f)]
141+
assert pt2_file, f"Unable to find an existing file among {pt2_files}"
142+
pt2_file = pt2_file[0]
131143
# self.assertExists(pt2_file)
132144
# ep = torch.export.load(pt2_file)
133145
# diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))

onnx_diagnostic/export/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def to_onnx(
112112
ort_fusions.optimize_for_ort(epo.model)
113113
if filename:
114114
epo.save(filename, external_data=True)
115+
if save_ep:
116+
if isinstance(save_ep, tuple):
117+
save_ep = save_ep[0]
118+
torch.export.save(epo.exported_program, f"{save_ep}.pt2")
115119
return epo
116120

117121
if exporter == "modelbuilder":

onnx_diagnostic/ext_test_case.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ def assert_onnx_disc(
13011301

13021302
ep = torch.export.load(ep)
13031303
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1304-
ep_model = ep.module()
1304+
ep_model = ep.module() # type: ignore[union-attr]
13051305
ep_expected = (
13061306
ep_model(*copy.deepcopy(ep_inputs))
13071307
if isinstance(ep_inputs, tuple)
@@ -1356,6 +1356,11 @@ def max_diff(self, *args, **kwargs):
13561356

13571357
return max_diff(*args, **kwargs)
13581358

1359+
def use_dyn_not_str(self, *args, **kwargs):
    """Forward the call to ``patch_inputs.use_dyn_not_str``.

    Thin convenience wrapper so that test cases can write
    ``self.use_dyn_not_str(...)``. All positional and keyword arguments
    are passed through unchanged and the helper's return value is
    returned as is; see
    ``onnx_diagnostic.torch_export_patches.patch_inputs.use_dyn_not_str``
    for the actual semantics.
    """
    # The helper is imported inside the method rather than at module level.
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    # ``**kwargs`` (not ``*kwargs``): a single star would unpack the
    # dictionary *keys* as extra positional arguments and drop the values.
    return use_dyn_not_str(*args, **kwargs)
1363+
13591364
def subloop(self, *args, verbose: int = 0):
13601365
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
13611366
if len(args) == 1:

onnx_diagnostic/helpers/ort_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,13 @@ def __init__(
134134

135135
self.sess = sess
136136
self.input_names = [i.name for i in sess.get_inputs()]
137+
assert (
138+
"" not in self.input_names
139+
), f"Input name cannot be empty but input_names={self.input_names}"
137140
self.output_names = [i.name for i in sess.get_outputs()]
141+
# Check the output names here; the input names are validated by the
# assertion just above.
assert (
    "" not in self.output_names
), f"Output name cannot be empty but output_names={self.output_names}"
138144
self.input_shapes = [i.shape for i in sess.get_inputs()]
139145
self.output_shapes = [i.shape for i in sess.get_outputs()]
140146
self.input_types = [i.type for i in sess.get_inputs()]
@@ -497,6 +503,7 @@ def run_dlpack(
497503
values = ORTC.OrtValueVector()
498504
device = -1
499505
for k, v in feeds.items():
506+
assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
500507
device = max(device, v.get_device())
501508
assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
502509
if not v.is_contiguous():

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -564,18 +564,14 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
564564
onx, sess = self._get_sess(node, inputs)
565565
self._cache[key] = onx, sess
566566

567-
feeds = dict(zip(node.input, inputs))
568-
if "" in feeds:
569-
cls = None
570-
for k, v in feeds.items():
571-
if k != "":
572-
cls = v.__class__
573-
break
574-
assert (
575-
cls is not None
576-
), f"Unable to get input class (array or tensor), feeds={string_type(feeds)}"
577-
feeds[""] = cls([0])
578-
567+
feeds = {}
568+
for i, val in zip(node.input, inputs):
569+
if i == "":
570+
assert (
571+
val is None
572+
), f"input name={i!r} but val={string_type(val, with_shape=True)}"
573+
continue
574+
feeds[i] = val
579575
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
580576
outputs = list(sess.run(None, feeds))
581577
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,11 @@ def _loop_cmp(
567567
print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")
568568

569569
placeholders = {node.name for node in ep.graph.nodes if node.op == "placeholder"}
570-
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers())}
570+
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
571571
placeholders_to_state_dict = {
572572
**{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
573573
**{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
574+
**{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
574575
}
575576
for n in onnx_results:
576577
if n not in placeholders:
@@ -588,6 +589,7 @@ def _loop_cmp(
588589
else:
589590
loop = list(enumerate(ep_graph_nodes))
590591

592+
already_run = set()
591593
ep_durations = {}
592594
yielded_nodes = 0
593595
max_abs = 0
@@ -641,8 +643,8 @@ def _loop_cmp(
641643
yield record
642644
else:
643645
assert node.name in placeholders_to_state_dict, (
644-
f"Unable to find placeholder {node.name!r} in "
645-
f"{sorted(placeholders_to_state_dict)}"
646+
f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
647+
f"existing: {sorted(placeholders_to_state_dict)}"
646648
)
647649
torch_results[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
648650
if verbose > 1:
@@ -683,6 +685,8 @@ def _loop_cmp(
683685
continue
684686

685687
for i_onnx in range(last_position, max_pos + 1):
688+
if i_onnx in already_run:
689+
continue
686690
node = onx.graph.node[i_onnx]
687691
if verbose > 1:
688692
print(
@@ -695,9 +699,16 @@ def _loop_cmp(
695699
f"mapped {yielded_nodes} maxabs {max_abs:1.5f}"
696700
)
697701
ref = run_cls(node, **run_cls_kwargs)
698-
feeds = {k: onnx_results[k] for k in node.input}
702+
feeds = {k: onnx_results[k] for k in node.input if k}
703+
assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
699704
begin = time.perf_counter()
700-
res = ref.run(None, feeds) # type: ignore[attr-defined]
705+
try:
706+
res = ref.run(None, feeds) # type: ignore[attr-defined]
707+
except Exception as e:
708+
raise RuntimeError(
709+
f"Unable to run node {node.op_type}, domain={node.domain} "
710+
f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
711+
) from e
701712
duration = time.perf_counter() - begin
702713
assert (
703714
not has_cuda
@@ -748,6 +759,7 @@ def _loop_cmp(
748759
if tmp.err_abs is not None:
749760
max_abs = max(max_abs, tmp.err_abs)
750761
yield tmp
762+
already_run.add(i_onnx)
751763

752764
last_position = max_pos + 1
753765

@@ -758,14 +770,17 @@ def _loop_cmp(
758770
f"to {len(onx.graph.node)}"
759771
)
760772
for i_onnx in range(last_position, len(onx.graph.node)):
773+
if i_onnx in already_run:
774+
continue
761775
node = onx.graph.node[i_onnx]
762776
if verbose > 1:
763777
print(
764778
f"[run_aligned] run onx.graph.node[{i_onnx}]: "
765779
f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
766780
)
767781
ref = run_cls(node, **run_cls_kwargs)
768-
feeds = {k: onnx_results[k] for k in node.input}
782+
feeds = {k: onnx_results[k] for k in node.input if k}
783+
assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
769784
begin = time.perf_counter()
770785
res = ref.run(None, feeds) # type: ignore[attr-defined]
771786
duration = time.perf_counter() - begin
@@ -800,6 +815,8 @@ def _loop_cmp(
800815
if tmp.err_abs is not None:
801816
max_abs = max(max_abs, tmp.err_abs)
802817
yield tmp
818+
already_run.add(i_onnx)
819+
803820
if verbose:
804821
print(f"[run_aligned] done with {yielded_nodes} mapped nodes")
805822
print(f"[run_aligned] max absolution error={max_abs}")

0 commit comments

Comments
 (0)