Skip to content

Commit 6c88174

Browse files
committed
Add ReportResultsComparison
1 parent ad6e61c commit 6c88174

File tree

6 files changed

+147
-2
lines changed

6 files changed

+147
-2
lines changed

_doc/api/reference/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ onnx_diagnostic.reference
1515
evaluator
1616
quantized_tensor
1717
ort_evaluator
18+
report_results_comparison
1819
torch_evaluator
1920

2021
ExtendedReferenceEvaluator
@@ -29,6 +30,12 @@ OnnxruntimeEvaluator
2930
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
3031
:members:
3132

33+
ReportResultsComparison
34+
++++++++++++++++++
35+
36+
.. autoclass:: onnx_diagnostic.reference.ReportResultsComparison
37+
:members:
38+
3239
TorchOnnxEvaluator
3340
++++++++++++++++++
3441

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.reference.report_results_comparison
3+
===================================================
4+
5+
.. automodule:: onnx_diagnostic.reference.report_results_comparison
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: ReportResultsComparison

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
import onnx.helper as oh
55
import onnx.numpy_helper as onh
66
import torch
7-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
88
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
99
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
10-
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
10+
from onnx_diagnostic.reference import (
11+
ExtendedReferenceEvaluator,
12+
TorchOnnxEvaluator,
13+
ReportResultsComparison,
14+
)
1115
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
1216
from onnx_diagnostic.reference.torch_evaluator import get_kernels
1317

@@ -1471,6 +1475,35 @@ def run(self, x, scale, bias=None):
14711475
self.assertEqualAny(expected, got, atol=1e-3)
14721476
self.assertEqual([1], LayerNormalizationOrt._shared)
14731477

1478+
@hide_stdout()
1479+
def test_report_results_comparison(self):
1480+
model = oh.make_model(
1481+
oh.make_graph(
1482+
[
1483+
oh.make_node("Cos", ["X"], ["nx"]),
1484+
oh.make_node("Sin", ["nx"], ["t"]),
1485+
oh.make_node("Exp", ["t"], ["u"]),
1486+
oh.make_node("Log", ["u"], ["uZ"]),
1487+
oh.make_node("Erf", ["uZ"], ["Z"]),
1488+
],
1489+
"dummy",
1490+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
1491+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
1492+
),
1493+
ir_version=9,
1494+
opset_imports=[oh.make_opsetid("", 18)],
1495+
)
1496+
x = torch.rand(5, 6, dtype=torch.float32)
1497+
onnx.checker.check_model(model)
1498+
cmp = ReportResultsComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp()))
1499+
cmp.clear()
1500+
feeds = dict(zip([i.name for i in model.graph.input], (x,)))
1501+
rt = TorchOnnxEvaluator(model, verbose=10)
1502+
rt.run(None, feeds, report_cmp=cmp)
1503+
d = {k: d["abs"] for k, d in cmp.value.items()}
1504+
self.assertEqual(d["nx", "r_cos"], 0)
1505+
self.assertEqual(d["u", "r_exp"], 0)
1506+
14741507

14751508
if __name__ == "__main__":
14761509
unittest.main(verbosity=2)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .evaluator import ExtendedReferenceEvaluator
22
from .ort_evaluator import OnnxruntimeEvaluator
33
from .torch_evaluator import TorchOnnxEvaluator
4+
from .report_results_comparison import ReportResultsComparison
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Dict, List, Tuple, Union
2+
3+
4+
ReportKeyType = Union[str, Tuple[Union[int, str], ...]]
5+
6+
7+
class ReportResultsComparison:
8+
"""
9+
Holds tensors a runtime can use as a reference to compare
10+
intermediate results.
11+
See :meth:`onnx_diagnostic.reference.TorchOnnxEvalutor.run`.
12+
13+
:param tensors: tensor
14+
"""
15+
16+
def __init__(self, tensors: Dict[ReportKeyType, "torch.Tensor"]): # noqa: F821
17+
from ..helpers.onnx_helper import dtype_to_tensor_dtype
18+
from ..helpers import max_diff
19+
20+
self.dtype_to_tensor_dtype = dtype_to_tensor_dtype
21+
self.max_diff = max_diff
22+
self.tensors = tensors
23+
self._build_mapping()
24+
25+
def key(self, tensor: "torch.Tensor") -> ReportKeyType: # noqa: F821
26+
"Returns a key for a tensor, (onnx dtype, shape)."
27+
return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
28+
29+
def _build_mapping(self):
30+
mapping = {}
31+
for k, v in self.tensors.items():
32+
key = self.key(v)
33+
if key not in mapping:
34+
mapping[key] = []
35+
mapping[key].append(k)
36+
self.mapping = mapping
37+
self.clear()
38+
39+
def clear(self):
40+
"""Clears the last report."""
41+
self.report_cmp = {}
42+
43+
@property
44+
def value(self) -> Dict[Tuple[str, ReportKeyType], Dict[str, Union[float, str]]]:
45+
"Returns the report."
46+
return self.report_cmp
47+
48+
def report(
49+
self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
50+
) -> List[Tuple[str, ReportKeyType]]:
51+
"""
52+
For every tensor in outputs, compares it to every tensor held by
53+
this class if it shares the same type and shape. The function returns
54+
the results of the comparison. The function also collects the results
55+
into a dictionary the user can retrieve later.
56+
"""
57+
res = []
58+
for name, tensor in outputs.items():
59+
key = self.key(tensor)
60+
if key not in self.mapping:
61+
continue
62+
cache = {}
63+
for held_key in self.mapping[key]:
64+
t2 = self.tensors[held_key]
65+
if hasattr(t2, "device") and hasattr(tensor, "device"):
66+
if t2.device in cache:
67+
t = cache[t2.device]
68+
else:
69+
cache[t2.device] = t = tensor.to(t2.device)
70+
diff = self.max_diff(t, t2)
71+
else:
72+
diff = self.max_diff(tensor, t2)
73+
res.append((name, held_key, diff))
74+
self.report_cmp[name, held_key] = diff
75+
return res

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from ..helpers.torch_helper import to_tensor
77
from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
8+
from .report_results_comparison import ReportResultsComparison
89
from . import torch_ops
910

1011

@@ -455,12 +456,17 @@ def run(
455456
self,
456457
outputs: Optional[List[str]],
457458
feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
459+
report_cmp: Optional[ReportResultsComparison] = None,
458460
) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
459461
"""
460462
Runs the ONNX model.
461463
462464
:param outputs: outputs required
463465
:param feeds: inputs
466+
:param report_cmp: used as a reference,
467+
every intermediate results is compare to every existing one,
468+
if not empty, it is an instance of
469+
:class:`onnx_diagnostic.reference.ReportResultsComparison`
464470
:return: output tensors.
465471
"""
466472
use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
@@ -532,6 +538,21 @@ def run(
532538
f"+R {kernel.output[0]}: "
533539
f"{self.runtime_info[kernel.output[0]].string_type()}"
534540
)
541+
if report_cmp:
542+
reported = report_cmp.report(
543+
dict(
544+
zip(
545+
kernel.output,
546+
(
547+
tuple(r.tensor for r in res)
548+
if isinstance(res, tuple)
549+
else (res.tensor,)
550+
),
551+
)
552+
)
553+
)
554+
if self.verbose > 1:
555+
print(f" -- report {len(reported)} comparisons")
535556

536557
# free intermediate results
537558
for name in self.last_used[it]:

0 commit comments

Comments
 (0)