Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Change Logs
===========

0.7.0
+++++

* :pr:`143`: compares intermediate results

0.6.3
+++++

Expand Down
7 changes: 7 additions & 0 deletions _doc/api/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ onnx_diagnostic.reference
evaluator
quantized_tensor
ort_evaluator
report_results_comparison
torch_evaluator

ExtendedReferenceEvaluator
Expand All @@ -29,6 +30,12 @@ OnnxruntimeEvaluator
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
:members:

ReportResultComparison
++++++++++++++++++++++

.. autoclass:: onnx_diagnostic.reference.ReportResultComparison
:members:

TorchOnnxEvaluator
++++++++++++++++++

Expand Down
8 changes: 8 additions & 0 deletions _doc/api/reference/report_results_comparison.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

onnx_diagnostic.reference.report_results_comparison
===================================================

.. automodule:: onnx_diagnostic.reference.report_results_comparison
:members:
:no-undoc-members:
:exclude-members: ReportResultComparison
5 changes: 3 additions & 2 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sphinx_runpython.conf_helper import has_dvipng, has_dvisvgm
import torch
from onnx_diagnostic import __version__
from onnx_diagnostic.doc import update_version_package

extensions = [
"sphinx.ext.autodoc",
Expand Down Expand Up @@ -40,8 +41,8 @@
project = "onnx-diagnostic"
copyright = "2025"
author = "Xavier Dupré"
version = __version__
release = __version__
version = update_version_package(__version__)
release = version
language = "en"
exclude_patterns = []
pygments_style = "sphinx"
Expand Down
88 changes: 82 additions & 6 deletions _doc/examples/plot_dump_intermediate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,30 @@

See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
for its definition. It is mostly used for unit test or example.

"""

import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison


model, inputs, ds = dummy_llm(dynamic_shapes=True)

# %%
# We use float16.
model = model.to(torch.float16)

# %%
# Let's check.

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
Expand Down Expand Up @@ -65,7 +74,7 @@
verbose=1,
storage_limit=2**28,
):
model(*inputs)
expected = model(*inputs)


# %%
Expand Down Expand Up @@ -124,7 +133,74 @@
epo.save("plot_dump_intermediate_results.onnx")

# %%
# It looks like the following.
# Discrepancies
# +++++++++++++
#
# We have a torch model, intermediate results and an ONNX graph
# equivalent to the torch model.
# Let's see how we can check the discrepancies.
# First the discrepancies of the whole model.

sess = onnxruntime.InferenceSession(
"plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")

# %%
# What about intermediate results?
# Let's use a runtime still based on :epkg:`onnxruntime`
# running an eager evaluation.

sess_eager = OnnxruntimeEvaluator(
"plot_dump_intermediate_results.onnx",
providers=["CPUExecutionProvider"],
torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")

# %%
# They are almost the same. That's good.
# Let's now dig into the intermediate results.
# They are compared to the outputs stored in saved_tensors
# during the execution of the model.
baseline = {}
for k, v in saved_tensors.items():
if k[-1] == "I": # inputs are excluded
continue
if isinstance(v, torch.Tensor):
baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v

report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)

# %%
# Let's see the results.

data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)

# %%
# Let's clean a little bit.
piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))

# %%
# We can identify which result is mapped to which expected tensor.

# %%
# Picture of the model
# ++++++++++++++++++++

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)

Expand Down
6 changes: 3 additions & 3 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,8 @@ The function replaces dynamic dimensions defined as strings by
Older versions
++++++++++++++

* `0.7.0 <../v0.7.0/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.6.2 <../v0.6.2/index.html>`_
* `0.6.1 <../v0.6.1/index.html>`_
* `0.6.0 <../v0.6.0/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
* `0.4.4 <../v0.4.4/index.html>`_
* `0.3.0 <../v0.3.0/index.html>`_
Expand All @@ -238,6 +236,7 @@ With the following versions:
import ml_dtypes
import sklearn
import onnx
import onnx_ir
import onnxruntime
import onnxscript
import torch
Expand All @@ -249,6 +248,7 @@ With the following versions:
ml_dtypes,
sklearn,
onnx,
onnx_ir,
onnxruntime,
onnxscript,
torch,
Expand Down
40 changes: 38 additions & 2 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
import onnx.helper as oh
import torch
import onnxruntime
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
from onnx_diagnostic.reference import (
OnnxruntimeEvaluator,
ExtendedReferenceEvaluator,
ReportResultComparison,
)

try:
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
except ImportError:
to_onnx = None


TFLOAT = onnx.TensorProto.FLOAT


class TestOnnxruntimeEvaluator(ExtTestCase):
def test_ort_eval_scan_cdist_add(self):

Expand Down Expand Up @@ -190,6 +197,35 @@ def test_ort_eval_loop(self):
got = ref.run(None, feeds)
self.assertEqualArray(expected, got[0])

@hide_stdout()
def test_report_results_comparison_ort(self):
    """Checks ``ReportResultComparison`` when driven by ``OnnxruntimeEvaluator``.

    The model is the chain ``Cos -> Sin -> Exp -> Log -> Erf``. Reference
    tensors are computed with torch for the input, the cosine and the
    exponential, and the report must match ``nx`` with ``r_cos`` and
    ``u`` with ``r_exp`` with an absolute error below ``1e-6``.
    """
    nodes = [
        oh.make_node("Cos", ["X"], ["nx"]),
        oh.make_node("Sin", ["nx"], ["t"]),
        oh.make_node("Exp", ["t"], ["u"]),
        oh.make_node("Log", ["u"], ["uZ"]),
        oh.make_node("Erf", ["uZ"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])]
    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs)
    model = oh.make_model(
        graph,
        ir_version=9,
        opset_imports=[oh.make_opsetid("", 18)],
    )
    x = torch.rand(5, 6, dtype=torch.float32)
    onnx.checker.check_model(model)

    # Reference tensors the intermediate results are compared against.
    references = {"r_x": x, "r_cos": x.cos(), "r_exp": x.cos().sin().exp()}
    cmp = ReportResultComparison(references)
    cmp.clear()

    input_names = [i.name for i in model.graph.input]
    feeds = dict(zip(input_names, (x,)))
    rt = OnnxruntimeEvaluator(model, verbose=10)
    rt.run(None, feeds, report_cmp=cmp)

    # Only the absolute error of every reported pair is kept.
    abs_errors = {key: row["abs"] for key, row in cmp.value.items()}
    self.assertLess(abs_errors[(0, "nx"), "r_cos"], 1e-6)
    self.assertLess(abs_errors[(2, "u"), "r_exp"], 1e-6)


if __name__ == "__main__":
unittest.main(verbosity=2)
44 changes: 42 additions & 2 deletions _unittests/ut_reference/test_torch_onnx_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import unittest
import numpy as np
import pandas
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
from onnx_diagnostic.reference import (
ExtendedReferenceEvaluator,
TorchOnnxEvaluator,
ReportResultComparison,
)
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
from onnx_diagnostic.reference.torch_evaluator import get_kernels

Expand Down Expand Up @@ -1471,6 +1476,41 @@ def run(self, x, scale, bias=None):
self.assertEqualAny(expected, got, atol=1e-3)
self.assertEqual([1], LayerNormalizationOrt._shared)

@hide_stdout()
def test_report_results_comparison(self):
    """Checks ``ReportResultComparison`` when driven by ``TorchOnnxEvaluator``.

    The model is the chain ``Cos -> Sin -> Exp -> Log -> Erf``. Reference
    tensors are computed with torch for the input, the cosine and the
    exponential. The report must match ``nx`` with ``r_cos`` and ``u``
    with ``r_exp`` with a null absolute error, and ``cmp.data`` must be
    a list which pivots into one row per intermediate result and one
    column per reference tensor.
    """
    nodes = [
        oh.make_node("Cos", ["X"], ["nx"]),
        oh.make_node("Sin", ["nx"], ["t"]),
        oh.make_node("Exp", ["t"], ["u"]),
        oh.make_node("Log", ["u"], ["uZ"]),
        oh.make_node("Erf", ["uZ"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])]
    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs)
    model = oh.make_model(
        graph,
        ir_version=9,
        opset_imports=[oh.make_opsetid("", 18)],
    )
    x = torch.rand(5, 6, dtype=torch.float32)
    onnx.checker.check_model(model)

    # Reference tensors the intermediate results are compared against.
    references = {"r_x": x, "r_cos": x.cos(), "r_exp": x.cos().sin().exp()}
    cmp = ReportResultComparison(references)
    cmp.clear()

    input_names = [i.name for i in model.graph.input]
    feeds = dict(zip(input_names, (x,)))
    rt = TorchOnnxEvaluator(model, verbose=10)
    rt.run(None, feeds, report_cmp=cmp)

    # Only the absolute error of every reported pair is kept.
    abs_errors = {key: row["abs"] for key, row in cmp.value.items()}
    self.assertEqual(abs_errors[(0, "nx"), "r_cos"], 0)
    self.assertEqual(abs_errors[(2, "u"), "r_exp"], 0)

    # The flat view of the report can be turned into a pivot table.
    rows = cmp.data
    self.assertIsInstance(rows, list)
    table = pandas.DataFrame(rows)
    piv = table.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
    expected_columns = ["r_cos", "r_exp", "r_x"]
    expected_index = [(0, "nx"), (1, "t"), (2, "u"), (3, "uZ"), (4, "Z")]
    self.assertEqual(list(piv.columns), expected_columns)
    self.assertEqual(list(piv.index), expected_index)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
Functions, classes to dig into a model when this one is right, slow, wrong...
"""

__version__ = "0.6.3"
__version__ = "0.7.0"
__author__ = "Xavier Dupré"
22 changes: 22 additions & 0 deletions onnx_diagnostic/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
import numpy as np


def get_latest_pypi_version(package_name="onnx-diagnostic", timeout: float = 10.0) -> str:
    """
    Returns the latest published version of a package on PyPI.

    :param package_name: name of the package to look up
    :param timeout: maximum number of seconds to wait for PyPI,
        without it, an unresponsive server blocks the caller forever
    :return: the content of ``info.version`` in the JSON document
        returned by ``https://pypi.org/pypi/<package_name>/json``
    :raises AssertionError: if the HTTP status code is not 200
    """
    # requests is imported here so that importing this module does not require it.
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        # Raised explicitly (same exception type as the former assert)
        # so that the check is not stripped when python runs with -O.
        raise AssertionError(f"Unable to retrieve the version response={response}")
    return response.json()["info"]["version"]


def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
    """
    Marks a version as a development version when it is ahead of the
    latest published release.

    Only the first two components (``major.minor``) of both versions
    are compared.

    :param version: version of the local package
    :param package_name: name of the package to look up with
        :func:`get_latest_pypi_version`
    :return: *version* unchanged if its ``major.minor`` prefix is the same
        as the published one, ``"<major>.<minor>.dev"`` otherwise
    """

    def _major_minor(full_version: str) -> str:
        "Keeps the first two components of a dotted version."
        return ".".join(full_version.split(".")[:2])

    published_prefix = _major_minor(get_latest_pypi_version(package_name))
    local_prefix = _major_minor(version)
    if published_prefix == local_prefix:
        return version
    return f"{local_prefix}.dev"


def reset_torch_transformers(gallery_conf, fname):
"Resets torch dynamo for :epkg:`sphinx-gallery`."
import matplotlib.pyplot as plt
Expand Down
1 change: 1 addition & 0 deletions onnx_diagnostic/reference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .evaluator import ExtendedReferenceEvaluator
from .ort_evaluator import OnnxruntimeEvaluator
from .torch_evaluator import TorchOnnxEvaluator
from .report_results_comparison import ReportResultComparison
Loading
Loading