Skip to content

Commit 2662894

Browse files
committed
fix a few things
1 parent 5551b42 commit 2662894

File tree

5 files changed

+27
-9
lines changed

5 files changed

+27
-9
lines changed

_doc/api/helpers/helper.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ onnx_diagnostic.helpers.helper
44

55
.. automodule:: onnx_diagnostic.helpers.helper
66
:no-undoc-members:
7-
:exclude-members: max_diff, string_diff, string_sig, string_type
7+
:exclude-members: flatten_object, max_diff, string_diff, string_sig, string_type

_doc/api/helpers/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ onnx_diagnostic.helpers
1818
rt_helper
1919
torch_test_helper
2020

21+
.. autofunction:: onnx_diagnostic.helpers.flatten_object
22+
2123
.. autofunction:: onnx_diagnostic.helpers.max_diff
2224

2325
.. autofunction:: onnx_diagnostic.helpers.string_diff
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .helper import max_diff, string_diff, string_sig, string_type
1+
from .helper import flatten_object, max_diff, string_diff, string_sig, string_type

onnx_diagnostic/helpers/ort_session.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,27 @@ def __init__(
101101
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
102102
else:
103103
raise ValueError(f"Unexpected value for providers={providers!r}")
104-
sess = onnxruntime.InferenceSession(
105-
sess if isinstance(sess, str) else sess.SerializeToString(),
106-
session_options,
107-
providers=providers,
108-
)
104+
try:
105+
sess = onnxruntime.InferenceSession(
106+
sess if isinstance(sess, str) else sess.SerializeToString(),
107+
session_options,
108+
providers=providers,
109+
)
110+
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
111+
if isinstance(sess, onnx.ModelProto):
112+
debug_path = "_debug_onnxruntine_evaluator_failure.onnx"
113+
onnx.save(
114+
sess,
115+
debug_path,
116+
save_as_external_data=True,
117+
all_tensors_to_one_file=True,
118+
)
119+
else:
120+
debug_path = sess
121+
raise RuntimeError(
122+
f"Unable to create a session stored in {debug_path!r}), "
123+
f"providers={providers}"
124+
) from e
109125
else:
110126
assert (
111127
session_options is None

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def forward(self, x, y):
151151
onnx.save(
152152
proto,
153153
dump_file,
154-
save_as_external_data=False,
155-
all_tensors_to_one_file=True,
154+
save_as_external_data=True,
155+
all_tensors_to_one_file=False,
156156
)
157157

158158

0 commit comments

Comments
 (0)