Skip to content

Commit 25f207a

Browse files
committed
improve examples
1 parent 62db406 commit 25f207a

File tree

12 files changed

+150
-29
lines changed

12 files changed

+150
-29
lines changed

_doc/api/reference/index.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ OnnxruntimeEvaluator
3030
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
3131
:members:
3232

33-
ReportResultsComparison
34-
+++++++++++++++++++++++
33+
ReportResultComparison
34+
++++++++++++++++++++++
3535

36-
.. autoclass:: onnx_diagnostic.reference.ReportResultsComparison
36+
.. autoclass:: onnx_diagnostic.reference.ReportResultComparison
3737
:members:
3838

3939
TorchOnnxEvaluator

_doc/api/reference/report_results_comparison.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ onnx_diagnostic.reference.report_results_comparison
55
.. automodule:: onnx_diagnostic.reference.report_results_comparison
66
:members:
77
:no-undoc-members:
8-
:exclude-members: ReportResultsComparison
8+
:exclude-members: ReportResultComparison

_doc/conf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sphinx_runpython.conf_helper import has_dvipng, has_dvisvgm
66
import torch
77
from onnx_diagnostic import __version__
8+
from onnx_diagnostic.doc import update_version_package
89

910
extensions = [
1011
"sphinx.ext.autodoc",
@@ -40,8 +41,8 @@
4041
project = "onnx-diagnostic"
4142
copyright = "2025"
4243
author = "Xavier Dupré"
43-
version = __version__
44-
release = __version__
44+
version = update_version_package(__version__)
45+
release = version
4546
language = "en"
4647
exclude_patterns = []
4748
pygments_style = "sphinx"

_doc/examples/plot_dump_intermediate_results.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,30 @@
1919
2020
See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
2121
for its definition. It is mostly used for unit test or example.
22-
2322
"""
2423

24+
import numpy as np
25+
import pandas
2526
import onnx
2627
import torch
28+
import onnxruntime
2729
from onnx_array_api.plotting.graphviz_helper import plot_dot
2830
from onnx_diagnostic import doc
29-
from onnx_diagnostic.helpers import string_type
30-
from onnx_diagnostic.helpers.torch_helper import dummy_llm
31+
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
32+
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
3133
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
32-
from onnx_diagnostic.helpers.torch_helper import steal_forward
34+
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison
3335

3436

3537
model, inputs, ds = dummy_llm(dynamic_shapes=True)
3638

39+
# %%
40+
# We use float16.
41+
model = model.to(torch.float16)
42+
43+
# %%
44+
# Let's check.
45+
3746
print(f"type(model)={type(model)}")
3847
print(f"inputs={string_type(inputs, with_shape=True)}")
3948
print(f"ds={string_type(ds, with_shape=True)}")
@@ -65,7 +74,7 @@
6574
verbose=1,
6675
storage_limit=2**28,
6776
):
68-
model(*inputs)
77+
expected = model(*inputs)
6978

7079

7180
# %%
@@ -124,7 +133,74 @@
124133
epo.save("plot_dump_intermediate_results.onnx")
125134

126135
# %%
127-
# It looks like the following.
136+
# Discrepancies
137+
# +++++++++++++
138+
#
139+
# We have a torch model, intermediate results and an ONNX graph
140+
# equivalent to the torch model.
141+
# Let's see how we can check the discrepancies.
142+
# First the discrepancies of the whole model.
143+
144+
sess = onnxruntime.InferenceSession(
145+
"plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
146+
)
147+
feeds = dict(
148+
zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
149+
)
150+
got = sess.run(None, feeds)
151+
diff = max_diff(expected, got)
152+
print(f"discrepancies torch/ORT: {string_diff(diff)}")
153+
154+
# %%
155+
# What about intermediate results?
156+
# Let's use a runtime still based on :epkg:`onnxruntime`
157+
# running an eager evaluation.
158+
159+
sess_eager = OnnxruntimeEvaluator(
160+
"plot_dump_intermediate_results.onnx",
161+
providers=["CPUExecutionProvider"],
162+
torch_or_numpy=True,
163+
)
164+
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
165+
got = sess_eager.run(None, feeds_tensor)
166+
diff = max_diff(expected, got)
167+
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")
168+
169+
# %%
170+
# They are almost the same. That's good.
171+
# Let's now dig into the intermediate results.
172+
# They are compared to the outputs stored in saved_tensors
173+
# during the execution of the model.
174+
baseline = {}
175+
for k, v in saved_tensors.items():
176+
if k[-1] == "I": # inputs are excluded
177+
continue
178+
if isinstance(v, torch.Tensor):
179+
baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v
180+
181+
report_cmp = ReportResultComparison(baseline)
182+
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)
183+
184+
# %%
185+
# Let's see the results.
186+
187+
data = report_cmp.data
188+
df = pandas.DataFrame(data)
189+
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
190+
print(piv)
191+
192+
# %%
193+
# Let's clean a little bit.
194+
piv[piv >= 1] = np.nan
195+
print(piv.dropna(axis=0, how="all"))
196+
197+
# %%
198+
# We can identity which results is mapped to which expected tensor.
199+
200+
# %%
201+
# Picture of the model
202+
# ++++++++++++++++++++
203+
128204
onx = onnx.load("plot_dump_intermediate_results.onnx")
129205
plot_dot(onx)
130206

_doc/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ With the following versions:
236236
import ml_dtypes
237237
import sklearn
238238
import onnx
239+
import onnx_ir
239240
import onnxruntime
240241
import onnxscript
241242
import torch
@@ -247,6 +248,7 @@ With the following versions:
247248
ml_dtypes,
248249
sklearn,
249250
onnx,
251+
onnx_ir,
250252
onnxruntime,
251253
onnxscript,
252254
torch,

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from onnx_diagnostic.reference import (
1010
OnnxruntimeEvaluator,
1111
ExtendedReferenceEvaluator,
12-
ReportResultsComparison,
12+
ReportResultComparison,
1313
)
1414

1515
try:
@@ -217,7 +217,7 @@ def test_report_results_comparison_ort(self):
217217
)
218218
x = torch.rand(5, 6, dtype=torch.float32)
219219
onnx.checker.check_model(model)
220-
cmp = ReportResultsComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
220+
cmp = ReportResultComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
221221
cmp.clear()
222222
feeds = dict(zip([i.name for i in model.graph.input], (x,)))
223223
rt = OnnxruntimeEvaluator(model, verbose=10)

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from onnx_diagnostic.reference import (
1212
ExtendedReferenceEvaluator,
1313
TorchOnnxEvaluator,
14-
ReportResultsComparison,
14+
ReportResultComparison,
1515
)
1616
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
1717
from onnx_diagnostic.reference.torch_evaluator import get_kernels
@@ -1496,7 +1496,7 @@ def test_report_results_comparison(self):
14961496
)
14971497
x = torch.rand(5, 6, dtype=torch.float32)
14981498
onnx.checker.check_model(model)
1499-
cmp = ReportResultsComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
1499+
cmp = ReportResultComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
15001500
cmp.clear()
15011501
feeds = dict(zip([i.name for i in model.graph.input], (x,)))
15021502
rt = TorchOnnxEvaluator(model, verbose=10)

onnx_diagnostic/doc.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@
22
import numpy as np
33

44

5+
def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
6+
"""Returns the latest published version."""
7+
8+
import requests
9+
10+
url = f"https://pypi.org/pypi/{package_name}/json"
11+
response = requests.get(url)
12+
13+
assert response.status_code == 200, f"Unable to retrieve the version response={response}"
14+
data = response.json()
15+
version = data["info"]["version"]
16+
return version
17+
18+
19+
def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
20+
"Adds dev if the major version is different from the latest published one."
21+
released = get_latest_pypi_version(package_name)
22+
shorten_r = ".".join(released.split(".")[:2])
23+
shorten_v = ".".join(version.split(".")[:2])
24+
return version if shorten_r == shorten_v else f"{shorten_v}.dev"
25+
26+
527
def reset_torch_transformers(gallery_conf, fname):
628
"Resets torch dynamo for :epkg:`sphinx-gallery`."
729
import matplotlib.pyplot as plt
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .evaluator import ExtendedReferenceEvaluator
22
from .ort_evaluator import OnnxruntimeEvaluator
33
from .torch_evaluator import TorchOnnxEvaluator
4-
from .report_results_comparison import ReportResultsComparison
4+
from .report_results_comparison import ReportResultComparison

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
InferenceSessionForNumpy,
2323
_InferenceSession,
2424
)
25-
from .report_results_comparison import ReportResultsComparison
25+
from ..helpers.torch_helper import to_tensor
26+
from .report_results_comparison import ReportResultComparison
2627
from .evaluator import ExtendedReferenceEvaluator
2728

2829

@@ -51,6 +52,8 @@ class OnnxruntimeEvaluator:
5152
:param ir_version: ir version to use when unknown
5253
:param opsets: opsets to use when unknown
5354
:param whole: if True, do not split node by node
55+
:param torch_or_numpy: force the use of one of them, Ture for torch,
56+
False for numpy, None to let the class choose
5457
"""
5558

5659
def __init__(
@@ -73,6 +76,7 @@ def __init__(
7376
ir_version: int = 10,
7477
opsets: Optional[Union[int, Dict[str, int]]] = None,
7578
whole: bool = False,
79+
torch_or_numpy: Optional[bool] = None,
7680
):
7781
if isinstance(proto, str):
7882
self.proto: Proto = load(proto)
@@ -104,8 +108,10 @@ def __init__(
104108
disable_aot_function_inlining=disable_aot_function_inlining,
105109
use_training_api=use_training_api,
106110
)
111+
self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
107112

108113
self.verbose = verbose
114+
self.torch_or_numpy = torch_or_numpy
109115
self.sess_: Optional[_InferenceSession] = None
110116
if whole:
111117
self.nodes: Optional[List[NodeProto]] = None
@@ -124,7 +130,10 @@ def __init__(
124130
)
125131
)
126132
self.rt_inits_ = (
127-
{init.name: to_array_extended(init) for init in self.proto.graph.initializer}
133+
{
134+
init.name: self.to_tensor_or_array(init)
135+
for init in self.proto.graph.initializer
136+
}
128137
if hasattr(self.proto, "graph")
129138
else {}
130139
)
@@ -192,13 +201,14 @@ def _log_arg(self, a: Any) -> Any:
192201
return a
193202
device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
194203
if hasattr(a, "shape"):
204+
prefix = "A:" if hasattr(a, "astype") else "T:"
195205
if self.verbose < 4: # noqa: PLR2004
196-
return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
206+
return f"{prefix}{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
197207
elements = a.ravel().tolist()
198208
if len(elements) > 10: # noqa: PLR2004
199209
elements = elements[:10]
200-
return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
201-
return f"{device}{a.dtype}:{a.shape}:{elements}"
210+
return f"{prefix}{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
211+
return f"{prefix}{device}{a.dtype}:{a.shape}:{elements}"
202212
if hasattr(a, "append"):
203213
return ", ".join(map(self._log_arg, a))
204214
return a
@@ -216,7 +226,7 @@ def run(
216226
outputs: Optional[List[str]],
217227
feed_inputs: Dict[str, Any],
218228
intermediate: bool = False,
219-
report_cmp: Optional[ReportResultsComparison] = None,
229+
report_cmp: Optional[ReportResultComparison] = None,
220230
) -> Union[Dict[str, Any], List[Any]]:
221231
"""
222232
Runs the model.
@@ -228,7 +238,7 @@ def run(
228238
:param report_cmp: used as a reference,
229239
every intermediate results is compare to every existing one,
230240
if not empty, it is an instance of
231-
:class:`onnx_diagnostic.reference.ReportResultsComparison`
241+
:class:`onnx_diagnostic.reference.ReportResultComparison`
232242
:return: outputs, as a list if return_all is False,
233243
as a dictionary if return_all is True
234244
"""
@@ -437,8 +447,12 @@ def _get_sess(
437447
cls = (
438448
InferenceSessionForNumpy
439449
if any(isinstance(i, np.ndarray) for i in inputs)
450+
and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
440451
else InferenceSessionForTorch
441452
)
453+
assert (
454+
cls is InferenceSessionForTorch
455+
), f"ERROR: {string_type(inputs, with_shape=True)}"
442456
try:
443457
sess = cls(onx, **self.session_kwargs)
444458
except (
@@ -497,6 +511,7 @@ def _get_sess_if(
497511
verbose=self.verbose,
498512
ir_version=self.ir_version,
499513
opsets=self.opsets,
514+
torch_or_numpy=self.torch_or_numpy,
500515
**self.session_kwargs,
501516
)
502517
return onx, sess
@@ -511,6 +526,7 @@ def _get_sess_local(
511526
verbose=self.verbose,
512527
ir_version=self.ir_version,
513528
opsets=self.opsets,
529+
torch_or_numpy=self.torch_or_numpy,
514530
**self.session_kwargs,
515531
)
516532
return ev.proto, sess
@@ -586,6 +602,7 @@ def _get_sess_scan(
586602
verbose=self.verbose,
587603
ir_version=self.ir_version,
588604
opsets=self.opsets,
605+
torch_or_numpy=self.torch_or_numpy,
589606
whole=True,
590607
**self.session_kwargs,
591608
)

0 commit comments

Comments
 (0)