Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Change Logs
===========

0.7.0
+++++

* :pr:`143`: compares intermediate results

0.6.3
+++++

Expand Down
7 changes: 7 additions & 0 deletions _doc/api/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ onnx_diagnostic.reference
evaluator
quantized_tensor
ort_evaluator
report_results_comparison
torch_evaluator

ExtendedReferenceEvaluator
Expand All @@ -29,6 +30,12 @@ OnnxruntimeEvaluator
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
:members:

ReportResultComparison
++++++++++++++++++++++

.. autoclass:: onnx_diagnostic.reference.ReportResultComparison
:members:

TorchOnnxEvaluator
++++++++++++++++++

Expand Down
8 changes: 8 additions & 0 deletions _doc/api/reference/report_results_comparison.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

onnx_diagnostic.reference.report_results_comparison
===================================================

.. automodule:: onnx_diagnostic.reference.report_results_comparison
:members:
:no-undoc-members:
:exclude-members: ReportResultComparison
5 changes: 3 additions & 2 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sphinx_runpython.conf_helper import has_dvipng, has_dvisvgm
import torch
from onnx_diagnostic import __version__
from onnx_diagnostic.doc import update_version_package

extensions = [
"sphinx.ext.autodoc",
Expand Down Expand Up @@ -40,8 +41,8 @@
project = "onnx-diagnostic"
copyright = "2025"
author = "Xavier Dupré"
version = __version__
release = __version__
version = update_version_package(__version__)
release = version
language = "en"
exclude_patterns = []
pygments_style = "sphinx"
Expand Down
88 changes: 82 additions & 6 deletions _doc/examples/plot_dump_intermediate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,30 @@

See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
for its definition. It is mostly used for unit test or example.

"""

import numpy as np
import pandas
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ReportResultComparison


model, inputs, ds = dummy_llm(dynamic_shapes=True)

# %%
# We use float16.
model = model.to(torch.float16)

# %%
# Let's check.

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")
Expand Down Expand Up @@ -65,7 +74,7 @@
verbose=1,
storage_limit=2**28,
):
model(*inputs)
expected = model(*inputs)


# %%
Expand Down Expand Up @@ -124,7 +133,74 @@
epo.save("plot_dump_intermediate_results.onnx")

# %%
# It looks like the following.
# Discrepancies
# +++++++++++++
#
# We have a torch model, intermediate results and an ONNX graph
# equivalent to the torch model.
# Let's see how we can check the discrepancies.
# First the discrepancies of the whole model.

sess = onnxruntime.InferenceSession(
"plot_dump_intermediate_results.onnx", providers=["CPUExecutionProvider"]
)
feeds = dict(
zip([i.name for i in sess.get_inputs()], [t.detach().cpu().numpy() for t in inputs])
)
got = sess.run(None, feeds)
diff = max_diff(expected, got)
print(f"discrepancies torch/ORT: {string_diff(diff)}")

# %%
# What about intermediate results?
# Let's use a runtime still based on :epkg:`onnxruntime`
# running an eager evaluation.

sess_eager = OnnxruntimeEvaluator(
"plot_dump_intermediate_results.onnx",
providers=["CPUExecutionProvider"],
torch_or_numpy=True,
)
feeds_tensor = dict(zip([i.name for i in sess.get_inputs()], inputs))
got = sess_eager.run(None, feeds_tensor)
diff = max_diff(expected, got)
print(f"discrepancies torch/eager ORT: {string_diff(diff)}")

# %%
# They are almost the same. That's good.
# Let's now dig into the intermediate results.
# They are compared to the outputs stored in saved_tensors
# during the execution of the model.
baseline = {}
for k, v in saved_tensors.items():
if k[-1] == "I": # inputs are excluded
continue
if isinstance(v, torch.Tensor):
baseline[f"{k[0]}.{k[1]}".replace("model.decoder", "decoder")] = v

report_cmp = ReportResultComparison(baseline)
sess_eager.run(None, feeds_tensor, report_cmp=report_cmp)

# %%
# Let's see the results.

data = report_cmp.data
df = pandas.DataFrame(data)
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
print(piv)

# %%
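# Let's clean up a little bit: discrepancies greater than or equal to 1
# are replaced by NaN and the rows left with only NaN are dropped.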
piv[piv >= 1] = np.nan
print(piv.dropna(axis=0, how="all"))

# %%
# We can identify which result is mapped to which expected tensor.

# %%
# Picture of the model
# ++++++++++++++++++++

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)

Expand Down
6 changes: 3 additions & 3 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,8 @@ The function replaces dynamic dimensions defined as strings by
Older versions
++++++++++++++

* `0.7.0 <../v0.7.0/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.6.2 <../v0.6.2/index.html>`_
* `0.6.1 <../v0.6.1/index.html>`_
* `0.6.0 <../v0.6.0/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
* `0.4.4 <../v0.4.4/index.html>`_
* `0.3.0 <../v0.3.0/index.html>`_
Expand All @@ -238,6 +236,7 @@ With the following versions:
import ml_dtypes
import sklearn
import onnx
import onnx_ir
import onnxruntime
import onnxscript
import torch
Expand All @@ -249,6 +248,7 @@ With the following versions:
ml_dtypes,
sklearn,
onnx,
onnx_ir,
onnxruntime,
onnxscript,
torch,
Expand Down
40 changes: 38 additions & 2 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
import onnx.helper as oh
import torch
import onnxruntime
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
from onnx_diagnostic.reference import (
OnnxruntimeEvaluator,
ExtendedReferenceEvaluator,
ReportResultComparison,
)

try:
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
except ImportError:
to_onnx = None


TFLOAT = onnx.TensorProto.FLOAT


class TestOnnxruntimeEvaluator(ExtTestCase):
def test_ort_eval_scan_cdist_add(self):

Expand Down Expand Up @@ -190,6 +197,35 @@ def test_ort_eval_loop(self):
got = ref.run(None, feeds)
self.assertEqualArray(expected, got[0])

@hide_stdout()
def test_report_results_comparison_ort(self):
    """
    Runs a chain of five unary operators with OnnxruntimeEvaluator and
    checks the absolute discrepancies collected by ReportResultComparison
    against tensors computed with torch.
    """
    # Every node consumes the output of the previous one:
    # X -> nx -> t -> u -> uZ -> Z.
    chain = [
        ("Cos", "X", "nx"),
        ("Sin", "nx", "t"),
        ("Exp", "t", "u"),
        ("Log", "u", "uZ"),
        ("Erf", "uZ", "Z"),
    ]
    graph = oh.make_graph(
        [oh.make_node(op_type, [src], [dst]) for op_type, src, dst in chain],
        "dummy",
        [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
    )
    model = oh.make_model(
        graph,
        ir_version=9,
        opset_imports=[oh.make_opsetid("", 18)],
    )
    x = torch.rand(5, 6, dtype=torch.float32)
    onnx.checker.check_model(model)

    # Reference tensors: the input and two intermediate results.
    baseline = dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp())
    report = ReportResultComparison(baseline)
    report.clear()

    feeds = {i.name: t for i, t in zip(model.graph.input, (x,))}
    evaluator = OnnxruntimeEvaluator(model, verbose=10)
    evaluator.run(None, feeds, report_cmp=report)

    # Only the absolute discrepancy of every reported pair is checked.
    abs_diff = {key: entry["abs"] for key, entry in report.value.items()}
    self.assertLess(abs_diff[(0, "nx"), "r_cos"], 1e-6)
    self.assertLess(abs_diff[(2, "u"), "r_exp"], 1e-6)


if __name__ == "__main__":
unittest.main(verbosity=2)
44 changes: 42 additions & 2 deletions _unittests/ut_reference/test_torch_onnx_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import unittest
import numpy as np
import pandas
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
from onnx_diagnostic.reference import (
ExtendedReferenceEvaluator,
TorchOnnxEvaluator,
ReportResultComparison,
)
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
from onnx_diagnostic.reference.torch_evaluator import get_kernels

Expand Down Expand Up @@ -1471,6 +1476,41 @@ def run(self, x, scale, bias=None):
self.assertEqualAny(expected, got, atol=1e-3)
self.assertEqual([1], LayerNormalizationOrt._shared)

@hide_stdout()
def test_report_results_comparison(self):
    """
    Runs a chain of five unary operators with TorchOnnxEvaluator and
    checks both the discrepancies collected by ReportResultComparison
    and the layout of the table built from ``data``.
    """
    # Every node consumes the output of the previous one:
    # X -> nx -> t -> u -> uZ -> Z.
    chain = [
        ("Cos", "X", "nx"),
        ("Sin", "nx", "t"),
        ("Exp", "t", "u"),
        ("Log", "u", "uZ"),
        ("Erf", "uZ", "Z"),
    ]
    graph = oh.make_graph(
        [oh.make_node(op_type, [src], [dst]) for op_type, src, dst in chain],
        "dummy",
        [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
    )
    model = oh.make_model(
        graph,
        ir_version=9,
        opset_imports=[oh.make_opsetid("", 18)],
    )
    x = torch.rand(5, 6, dtype=torch.float32)
    onnx.checker.check_model(model)

    # Reference tensors: the input and two intermediate results.
    baseline = dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp())
    report = ReportResultComparison(baseline)
    report.clear()

    feeds = {i.name: t for i, t in zip(model.graph.input, (x,))}
    evaluator = TorchOnnxEvaluator(model, verbose=10)
    evaluator.run(None, feeds, report_cmp=report)

    # The matching results are expected to be exactly equal here.
    abs_diff = {key: entry["abs"] for key, entry in report.value.items()}
    self.assertEqual(abs_diff[(0, "nx"), "r_cos"], 0)
    self.assertEqual(abs_diff[(2, "u"), "r_exp"], 0)

    # ``data`` is a list which pandas turns into a table:
    # one row per computed result, one column per reference tensor.
    rows = report.data
    self.assertIsInstance(rows, list)
    table = pandas.DataFrame(rows)
    piv = table.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
    expected_columns = ["r_cos", "r_exp", "r_x"]
    expected_index = [(0, "nx"), (1, "t"), (2, "u"), (3, "uZ"), (4, "Z")]
    self.assertEqual(list(piv.columns), expected_columns)
    self.assertEqual(list(piv.index), expected_index)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
Functions, classes to dig into a model when this one is right, slow, wrong...
"""

__version__ = "0.6.3"
__version__ = "0.7.0"
__author__ = "Xavier Dupré"
22 changes: 22 additions & 0 deletions onnx_diagnostic/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
import numpy as np


def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
    """
    Returns the latest published version of a package.

    The version is read from the field ``info.version`` of the JSON
    document served at ``https://pypi.org/pypi/<package_name>/json``.

    :param package_name: name of the package
    :return: the version string found in the response
    :raises AssertionError: if the status code of the response is not 200
    """
    # Local import: requests is only loaded when this function is called.
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    # Without a timeout, the call may block forever if the server
    # stops answering.
    response = requests.get(url, timeout=30)

    # Explicit raise instead of ``assert``: assert statements are removed
    # when python runs with -O, the exception type stays the same.
    if response.status_code != 200:
        raise AssertionError(f"Unable to retrieve the version response={response}")
    return response.json()["info"]["version"]


def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
    """
    Marks a version as a development version when it is ahead of (or
    simply different from) the latest published release.

    Only the first two components (``major.minor``) of both versions
    are compared.

    :param version: version of the package being documented
    :param package_name: name used to fetch the latest published version
    :return: *version* unchanged if ``major.minor`` is the same as the
        latest published one, ``<major>.<minor>.dev`` otherwise
    """
    released = get_latest_pypi_version(package_name)
    # Keeps the first two dot-separated components of each version.
    latest, current = (".".join(v.split(".")[:2]) for v in (released, version))
    if latest == current:
        return version
    return f"{current}.dev"


def reset_torch_transformers(gallery_conf, fname):
"Resets torch dynamo for :epkg:`sphinx-gallery`."
import matplotlib.pyplot as plt
Expand Down
1 change: 1 addition & 0 deletions onnx_diagnostic/reference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .evaluator import ExtendedReferenceEvaluator
from .ort_evaluator import OnnxruntimeEvaluator
from .torch_evaluator import TorchOnnxEvaluator
from .report_results_comparison import ReportResultComparison
Loading
Loading