Skip to content

Commit 62db406

Browse files
committed
better report
1 parent f0e43d4 commit 62db406

File tree

5 files changed

+86
-18
lines changed

5 files changed

+86
-18
lines changed

_doc/api/reference/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ OnnxruntimeEvaluator
3131
:members:
3232

3333
ReportResultsComparison
34-
++++++++++++++++++
34+
+++++++++++++++++++++++
3535

3636
.. autoclass:: onnx_diagnostic.reference.ReportResultsComparison
3737
:members:

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,23 @@
44
import onnx.helper as oh
55
import torch
66
import onnxruntime
7-
from onnx_diagnostic.ext_test_case import ExtTestCase
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
88
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
9-
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
9+
from onnx_diagnostic.reference import (
10+
OnnxruntimeEvaluator,
11+
ExtendedReferenceEvaluator,
12+
ReportResultsComparison,
13+
)
1014

1115
try:
1216
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
1317
except ImportError:
1418
to_onnx = None
1519

1620

21+
TFLOAT = onnx.TensorProto.FLOAT
22+
23+
1724
class TestOnnxruntimeEvaluator(ExtTestCase):
1825
def test_ort_eval_scan_cdist_add(self):
1926

@@ -190,6 +197,35 @@ def test_ort_eval_loop(self):
190197
got = ref.run(None, feeds)
191198
self.assertEqualArray(expected, got[0])
192199

200+
@hide_stdout()
201+
def test_report_results_comparison_ort(self):
202+
model = oh.make_model(
203+
oh.make_graph(
204+
[
205+
oh.make_node("Cos", ["X"], ["nx"]),
206+
oh.make_node("Sin", ["nx"], ["t"]),
207+
oh.make_node("Exp", ["t"], ["u"]),
208+
oh.make_node("Log", ["u"], ["uZ"]),
209+
oh.make_node("Erf", ["uZ"], ["Z"]),
210+
],
211+
"dummy",
212+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
213+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
214+
),
215+
ir_version=9,
216+
opset_imports=[oh.make_opsetid("", 18)],
217+
)
218+
x = torch.rand(5, 6, dtype=torch.float32)
219+
onnx.checker.check_model(model)
220+
cmp = ReportResultsComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
221+
cmp.clear()
222+
feeds = dict(zip([i.name for i in model.graph.input], (x,)))
223+
rt = OnnxruntimeEvaluator(model, verbose=10)
224+
rt.run(None, feeds, report_cmp=cmp)
225+
d = {k: d["abs"] for k, d in cmp.value.items()}
226+
self.assertLess(d[(0, "nx"), "r_cos"], 1e-6)
227+
self.assertLess(d[(2, "u"), "r_exp"], 1e-6)
228+
193229

194230
if __name__ == "__main__":
195231
unittest.main(verbosity=2)

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import numpy as np
3+
import pandas
34
import onnx
45
import onnx.helper as oh
56
import onnx.numpy_helper as onh
@@ -1501,8 +1502,14 @@ def test_report_results_comparison(self):
15011502
rt = TorchOnnxEvaluator(model, verbose=10)
15021503
rt.run(None, feeds, report_cmp=cmp)
15031504
d = {k: d["abs"] for k, d in cmp.value.items()}
1504-
self.assertEqual(d["nx", "r_cos"], 0)
1505-
self.assertEqual(d["u", "r_exp"], 0)
1505+
self.assertEqual(d[(0, "nx"), "r_cos"], 0)
1506+
self.assertEqual(d[(2, "u"), "r_exp"], 0)
1507+
data = cmp.data
1508+
self.assertIsInstance(data, list)
1509+
df = pandas.DataFrame(data)
1510+
piv = df.pivot(index=("run_index", "run_name"), columns="ref_name", values="abs")
1511+
self.assertEqual(list(piv.columns), ["r_cos", "r_exp", "r_x"])
1512+
self.assertEqual(list(piv.index), [(0, "nx"), (1, "t"), (2, "u"), (3, "uZ"), (4, "Z")])
15061513

15071514

15081515
if __name__ == "__main__":

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
InferenceSessionForNumpy,
2323
_InferenceSession,
2424
)
25+
from .report_results_comparison import ReportResultsComparison
2526
from .evaluator import ExtendedReferenceEvaluator
2627

28+
2729
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
2830
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
2931

@@ -214,16 +216,21 @@ def run(
214216
outputs: Optional[List[str]],
215217
feed_inputs: Dict[str, Any],
216218
intermediate: bool = False,
219+
report_cmp: Optional[ReportResultsComparison] = None,
217220
) -> Union[Dict[str, Any], List[Any]]:
218221
"""
219-
Runs the model.
220-
It only works with numpy arrays.
221-
222-
:param outputs: required outputs or None for all
223-
:param feed_inputs: inputs
224-
:param intermediate: returns all output instead of the last ones
222+
Runs the model.
223+
It only works with numpy arrays.
224+
225+
:param outputs: required outputs or None for all
226+
:param feed_inputs: inputs
227+
:param intermediate: returns all output instead of the last ones
228+
:param report_cmp: used as a reference,
229+
every intermediate results is compare to every existing one,
230+
if not empty, it is an instance of
231+
:class:`onnx_diagnostic.reference.ReportResultsComparison`
225232
:return: outputs, as a list if return_all is False,
226-
as a dictionary if return_all is True
233+
as a dictionary if return_all is True
227234
"""
228235
if self.rt_nodes_ is None:
229236
# runs a whole
@@ -267,6 +274,10 @@ def run(
267274
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
268275
assert isinstance(name, str), f"unexpected type for name {type(name)}"
269276
results[name] = value
277+
if report_cmp:
278+
reported = report_cmp.report(dict(zip(node.output, outputs)))
279+
if self.verbose > 1:
280+
print(f" -- report {len(reported)} comparisons")
270281
if not intermediate:
271282
self._clean_unused_inplace(i_node, node, results)
272283

onnx_diagnostic/reference/report_results_comparison.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple, Union
1+
from typing import Any, Dict, List, Tuple, Union
22

33

44
ReportKeyNameType = Union[str, Tuple[str, int, str]]
@@ -9,7 +9,7 @@ class ReportResultsComparison:
99
"""
1010
Holds tensors a runtime can use as a reference to compare
1111
intermediate results.
12-
See :meth:`onnx_diagnostic.reference.TorchOnnxEvalutor.run`.
12+
See :meth:`onnx_diagnostic.reference.TorchOnnxEvaluator.run`.
1313
1414
:param tensors: tensor
1515
"""
@@ -40,23 +40,37 @@ def _build_mapping(self):
4040
def clear(self):
4141
"""Clears the last report."""
4242
self.report_cmp = {}
43+
self.unique_run_names = set()
4344

4445
@property
4546
def value(self) -> Dict[Tuple[str, ReportKeyNameType], Dict[str, Union[float, str]]]:
4647
"Returns the report."
4748
return self.report_cmp
4849

50+
@property
51+
def data(self) -> List[Dict[str, Any]]:
52+
"Returns data which can be consumed by a dataframe."
53+
rows = []
54+
for k, v in self.value.items():
55+
(i_run, run_name), ref_name = k
56+
d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
57+
d.update(v)
58+
rows.append(d)
59+
return rows
60+
4961
def report(
5062
self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
51-
) -> List[Tuple[str, ReportKeyNameType, Dict[str, Union[float, str]]]]:
63+
) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
5264
"""
5365
For every tensor in outputs, compares it to every tensor held by
5466
this class if it shares the same type and shape. The function returns
5567
the results of the comparison. The function also collects the results
5668
into a dictionary the user can retrieve later.
5769
"""
58-
res: List[Tuple[str, ReportKeyNameType, Dict[str, Union[float, str]]]] = []
70+
res: List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]] = []
5971
for name, tensor in outputs.items():
72+
i_run = len(self.unique_run_names)
73+
self.unique_run_names.add(name)
6074
key = self.key(tensor)
6175
if key not in self.mapping:
6276
continue
@@ -71,6 +85,6 @@ def report(
7185
diff = self.max_diff(t, t2)
7286
else:
7387
diff = self.max_diff(tensor, t2)
74-
res.append((name, held_key, diff)) # type: ignore[arg-type]
75-
self.report_cmp[name, held_key] = diff
88+
res.append((i_run, name, held_key, diff)) # type: ignore[arg-type]
89+
self.report_cmp[(i_run, name), held_key] = diff
7690
return res

0 commit comments

Comments
 (0)