|
4 | 4 | import onnx.helper as oh |
5 | 5 | import onnx.numpy_helper as onh |
6 | 6 | import torch |
7 | | -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings |
| 7 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout |
8 | 8 | from onnx_diagnostic.helpers.onnx_helper import from_array_extended |
9 | 9 | from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype |
10 | | -from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator |
| 10 | +from onnx_diagnostic.reference import ( |
| 11 | + ExtendedReferenceEvaluator, |
| 12 | + TorchOnnxEvaluator, |
| 13 | + ReportResultsComparison, |
| 14 | +) |
11 | 15 | from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor |
12 | 16 | from onnx_diagnostic.reference.torch_evaluator import get_kernels |
13 | 17 |
|
@@ -1471,6 +1475,35 @@ def run(self, x, scale, bias=None): |
1471 | 1475 | self.assertEqualAny(expected, got, atol=1e-3) |
1472 | 1476 | self.assertEqual([1], LayerNormalizationOrt._shared) |
1473 | 1477 |
|
| 1478 | + @hide_stdout() |
| 1479 | + def test_report_results_comparison(self): |
| 1480 | + model = oh.make_model( |
| 1481 | + oh.make_graph( |
| 1482 | + [ |
| 1483 | + oh.make_node("Cos", ["X"], ["nx"]), |
| 1484 | + oh.make_node("Sin", ["nx"], ["t"]), |
| 1485 | + oh.make_node("Exp", ["t"], ["u"]), |
| 1486 | + oh.make_node("Log", ["u"], ["uZ"]), |
| 1487 | + oh.make_node("Erf", ["uZ"], ["Z"]), |
| 1488 | + ], |
| 1489 | + "dummy", |
| 1490 | + [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])], |
| 1491 | + [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])], |
| 1492 | + ), |
| 1493 | + ir_version=9, |
| 1494 | + opset_imports=[oh.make_opsetid("", 18)], |
| 1495 | + ) |
| 1496 | + x = torch.rand(5, 6, dtype=torch.float32) |
| 1497 | + onnx.checker.check_model(model) |
| 1498 | + cmp = ReportResultsComparison(dict(r_x=x, r_cos=x.cos(), r_exp=x.cos().sin().exp())) |
| 1499 | + cmp.clear() |
| 1500 | + feeds = dict(zip([i.name for i in model.graph.input], (x,))) |
| 1501 | + rt = TorchOnnxEvaluator(model, verbose=10) |
| 1502 | + rt.run(None, feeds, report_cmp=cmp) |
| 1503 | + d = {k: d["abs"] for k, d in cmp.value.items()} |
| 1504 | + self.assertEqual(d["nx", "r_cos"], 0) |
| 1505 | + self.assertEqual(d["u", "r_exp"], 0) |
| 1506 | + |
1474 | 1507 |
|
1475 | 1508 | if __name__ == "__main__": |
1476 | 1509 | unittest.main(verbosity=2) |
0 commit comments