
Commit 849049b

Improves investigate_onnxruntime_issue
1 parent cd111f5 commit 849049b

6 files changed: +197 -16 lines


CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
@@ -1,6 +1,11 @@
 Change Logs
 ===========

+0.2.0
++++++
+
+* :pr:`7`: improves function ``investigate_onnxruntime_issue``
+
 0.1.0
 +++++

_doc/examples/plot_export_tiny_llm.py

Lines changed: 37 additions & 6 deletions
@@ -1,23 +1,33 @@
 """
 .. _l-plot-tiny-llm-export:

-Export LLM with dynamic shapes
-==============================
+Steal method forward to guess the dynamic shapes
+================================================
+
+Inputs are always dynamic with LLMs, which is why dynamic shapes
+need to be specified when an LLM is exported with :func:`torch.export.export`.
+Most of the examples on :epkg:`HuggingFace` use method
+:meth:`transformers.GenerationMixin.generate` but we only want to
+export the model and its method ``forward``.
+
+This example shows how to guess the inputs of this method even though the model
+is executed through method ``generate``.

 We focus on the model
 `Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
 To avoid downloading any weights, we write a function creating a
 random model based on the same architecture.

-Guess the cache dimension
-+++++++++++++++++++++++++
+Steal the forward method
+++++++++++++++++++++++++

 The first step is to guess the dummy inputs.
 Let's use the true model for that.
 We use the dummy example from the model page.
 """

 import copy
+import pprint
 import torch
 import transformers
 from onnx_diagnostic.helpers import string_type
@@ -64,8 +74,13 @@ def _forward_(*args, _f=None, **kwargs):
 model.forward = keep_model_forward

 # %%
-# The model creation
-# ++++++++++++++++++
+# Untrained model
+# +++++++++++++++
+#
+# This part can be skipped if you are only interested in exporting
+# the original model. It is useful to create a unit test to ensure
+# a specific architecture can be exported despite the many changes
+# brought to :epkg:`torch` or :epkg:`transformers`.
 #
 # Let's create an untrained model using the config file provided
 # `config.json <https://huggingface.co/arnir0/Tiny-LLM/blob/main/config.json>`_
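The mechanics of building the untrained model are not shown in this hunk; presumably they follow the usual transformers pattern of instantiating from a config instead of downloading weights. A hedged sketch of that pattern (not the example's exact code):

import transformers

# build an architecture-only model: same config as Tiny-LLM, random weights
config = transformers.AutoConfig.from_pretrained("arnir0/Tiny-LLM")
untrained = transformers.AutoModelForCausalLM.from_config(config)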
@@ -126,6 +141,22 @@ def _forward_(*args, _f=None, **kwargs):
 # ++++++++++++++++++++++++++
 #
 # Let's use the same dummy inputs but we use the downloaded model.
+# Dummy inputs and dynamic shapes are created by function
+# :func:`onnx_diagnostic.torch_models.llms.get_tiny_llm`.
+
+data = get_tiny_llm()
+inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
+
+# %%
+# Let's print the inputs.
+
+print(string_type(inputs, with_shape=True))
+
+# %% Let's print the dynamic shapes
+pprint.pprint(dynamic_shapes)
+
+# %%
+# And let's finally export.

 try:
     ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
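The renamed sections make the technique explicit: instead of exporting through ``generate``, the example temporarily swaps ``model.forward`` for a spy that records the exact arguments ``generate`` passes to it, then restores the original method (the ``_forward_(*args, _f=None, **kwargs)`` wrapper visible in the hunk headers plays that role). A minimal sketch of the trick, assuming the Tiny-LLM checkpoint is reachable; ``stored_inputs`` and ``forward_spy`` are illustrative names, not the example's exact code:

import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
tokenizer = transformers.AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")

stored_inputs = []  # every (args, kwargs) that generate() sends to forward()
keep_model_forward = model.forward

def forward_spy(*args, **kwargs):
    stored_inputs.append((args, kwargs))
    return keep_model_forward(*args, **kwargs)

model.forward = forward_spy
prompt = tokenizer("Continue: it rains...", return_tensors="pt")
model.generate(**prompt, max_new_tokens=2)
model.forward = keep_model_forward  # always restore the original method

# the recorded kwargs (input_ids, attention_mask, past_key_values, ...) are
# the dummy inputs torch.export.export needs, cache included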
_doc/examples/plot_failing_model_extract.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+"""
+.. _l-plot-failing-model-extract:
+
+Find where a model fails by running submodels
+=============================================
+
+Let's assume :epkg:`onnxruntime` crashes without telling why or where.
+The first thing to do is to locate where. For that, we extract every submodel
+starting from the inputs and running the first *n* nodes of the model.
+The model is likely to fail for some *n*. Then the failing node is known.
+
+A failing model
++++++++++++++++
+
+The issue here is an operator ``Cast`` trying to convert a result
+into a non-existing type.
+"""
+
+import numpy as np
+import onnx
+import onnx.helper as oh
+import onnxruntime
+from onnx_diagnostic.helpers import from_array_extended
+from onnx_diagnostic.ort_session import investigate_onnxruntime_issue
+
+TFLOAT = onnx.TensorProto.FLOAT
+
+model = oh.make_model(
+    oh.make_graph(
+        [
+            oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
+            oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
+            oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
+            oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
+            oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
+        ],
+        "nd",
+        [
+            oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
+            oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+        ],
+        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
+        [from_array_extended(np.array([1], dtype=np.float32), name="one")],
+    ),
+    opset_imports=[oh.make_opsetid("", 18)],
+    ir_version=9,
+)
+
+# %%
+# We check it is failing.
+
+try:
+    onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+    print(e)
+
+
+# %%
+# Shape inference
+# +++++++++++++++
+#
+# Building submodels requires knowing the output types.
+# We run shape inference on the model.
+shaped_model = onnx.shape_inference.infer_shapes(model)
+
+
+# %%
+# Looping over the nodes
+# ++++++++++++++++++++++
+#
+#
+
+failing = investigate_onnxruntime_issue(shaped_model, providers="cpu", verbose=1, quiet=True)
+
+# %%
+# Let's print the failing node.
+print(failing)
+
+
+# %%
+# Detect an issue with shape inference
+# ++++++++++++++++++++++++++++++++++++
+#
+# We could have caught the error sooner by asking shape inference
+# to raise an exception if one node could not be processed.
+# It means either the node is a custom node
+# and shape inference has no way to guess the output type and shape
+# for this node, or shape inference failed.
+
+try:
+    onnx.shape_inference.infer_shapes(model, strict_mode=True)
+except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e:
+    print(e)
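The whole example reduces to one loop: if the full model crashes, run the submodel ending at each node's outputs until one fails. A rough sketch of that search, assuming shape inference already ran so every intermediate result is typed; ``find_failing_node`` is a hypothetical helper, not the library's implementation:

import onnx
import onnx.utils
import onnxruntime

def find_failing_node(shaped_model: onnx.ModelProto):
    # the graph inputs are reused as inputs of every submodel
    input_names = [i.name for i in shaped_model.graph.input]
    extractor = onnx.utils.Extractor(shaped_model)
    for i, node in enumerate(shaped_model.graph.node):
        # submodel running nodes 0..i; shape inference supplies the output types
        submodel = extractor.extract_model(input_names, list(node.output))
        try:
            onnxruntime.InferenceSession(
                submodel.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        except Exception:
            return i, node  # the first node onnxruntime refuses
    return None

On the model above this stops at the ``Cast`` node with ``to=999``.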

_doc/index.rst

Lines changed: 5 additions & 3 deletions
@@ -45,8 +45,11 @@ Source are `sdpython/onnx-diagnostic
 CHANGELOGS
 license

+**Enlightening Examples**

-**Some usefuls tools**
+* :ref:`l-plot-tiny-llm-export`
+
+**Some Useful Tools**

 .. code-block:: python

@@ -135,7 +138,6 @@ Size of the package:
 gr = df[["dir", "ext", "lines", "chars"]].groupby(["ext", "dir"]).sum()
 print(gr)

-Older versions
-++++++++++++++
+**Older versions**

 * `0.1.0 <../v0.1.0/index.html>`_

_unittests/ut_xrun_doc/test_ort_session.py

Lines changed: 11 additions & 0 deletions
@@ -182,6 +182,17 @@ def test_investigate_onnxruntime_issue_torch(self):
             dump_filename="test_investigate_onnxruntime_issue_torch.onnx",
         )

+    @hide_stdout()
+    def test_investigate_onnxruntime_issue_torch_quiet(self):
+        model, feeds, _expected = self._get_model()
+        investigate_onnxruntime_issue(
+            model,
+            feeds=feeds,
+            verbose=10,
+            dump_filename="test_investigate_onnxruntime_issue_torch.onnx",
+            quiet=True,
+        )
+
     @hide_stdout()
     def test_investigate_onnxruntime_issue_numpy(self):
         model, feeds, _expected = self._get_model()

onnx_diagnostic/ort_session.py

Lines changed: 46 additions & 7 deletions
@@ -408,6 +408,7 @@ def investigate_onnxruntime_issue(
     verbose: int = 0,
     dump_filename: Optional[str] = None,
     infer_shapes: bool = True,
+    quiet: bool = False,
 ):
     """
     Investigates a crashing model. It tries every node until
@@ -433,6 +434,8 @@ def investigate_onnxruntime_issue(
     :param verbose: verbosity level
     :param dump_filename: if not None, the function dumps the last model run
     :param infer_shapes: run shape inference
+    :param quiet: if True, catches the exception and returns the failing node,
+        if False, lets the exception be raised

     The most simple use:

@@ -531,7 +534,19 @@ def investigate_onnxruntime_issue(
                 f"{', '.join(node.output)}"
             )
         e = onnx.utils.Extractor(onx)
-        extracted = e.extract_model(input_names, node.output)
+        if quiet:
+            try:
+                extracted = e.extract_model(input_names, node.output)
+            except Exception as e:
+                if verbose > 0:
+                    print(
+                        f"[investigate_onnxruntime_issue] cannot extract "
+                        f"model at node {i} due to {e}"
+                    )
+                return node
+        else:
+            extracted = e.extract_model(input_names, node.output)
+
         if dump_filename:
             if verbose > 1:
                 print(f"[investigate_onnxruntime_issue] save into {dump_filename}")
@@ -540,11 +555,11 @@ def investigate_onnxruntime_issue(
         if verbose > 1:
             print("[investigate_onnxruntime_issue] create the session")

-        if onnx_to_session:
-            sess = onnx_to_session(onx)
-        else:
-            sess = cls(
-                extracted,
+        def _make_session(proto):
+            if onnx_to_session:
+                return onnx_to_session(proto)
+            return cls(
+                proto,
                 session_options=session_options,
                 providers=providers,
                 nvtx=nvtx,
@@ -557,6 +572,19 @@ def investigate_onnxruntime_issue(
                 use_training_api=use_training_api,
             )

+        if quiet:
+            try:
+                sess = _make_session(extracted)
+            except Exception as e:
+                if verbose > 0:
+                    print(
+                        f"[investigate_onnxruntime_issue] cannot create session "
+                        f"at node {i} due to {e}"
+                    )
+                return node
+        else:
+            sess = _make_session(extracted)
+
         if not feeds:
             if verbose > 1:
                 print("[investigate_onnxruntime_issue] session created")
@@ -565,7 +593,18 @@ def investigate_onnxruntime_issue(
         if verbose > 1:
             print("[investigate_onnxruntime_issue] running session")

-        sess.run(None, feeds)
+        if quiet:
+            try:
+                sess.run(None, feeds)
+            except Exception as e:
+                if verbose > 0:
+                    print(
+                        f"[investigate_onnxruntime_issue] cannot run session "
+                        f"at node {i} due to {e}"
+                    )
+                return node
+        else:
+            sess.run(None, feeds)

     if verbose > 0:
         print("[investigate_onnxruntime_issue] done.")
