|
4 | 4 | import onnx.helper as oh |
5 | 5 | import onnx.numpy_helper as onh |
6 | 6 | import torch |
7 | | -from onnx_diagnostic.ext_test_case import ExtTestCase |
| 7 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings |
8 | 8 | from onnx_diagnostic.helpers.onnx_helper import from_array_extended |
9 | 9 | from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator |
10 | 10 | from onnx_diagnostic.reference.torch_evaluator import get_kernels |
@@ -801,6 +801,56 @@ def test_op_constant_of_shape(self): |
801 | 801 | onnx.checker.check_model(model) |
802 | 802 | self._finalize_test(model, torch.tensor([4, 5], dtype=torch.int64)) |
803 | 803 |
|
| 804 | + def test_op_trilu(self): |
| 805 | + model = oh.make_model( |
| 806 | + oh.make_graph( |
| 807 | + [oh.make_node("Trilu", ["X"], ["Z"])], |
| 808 | + "dummy", |
| 809 | + [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])], |
| 810 | + [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])], |
| 811 | + ), |
| 812 | + ir_version=9, |
| 813 | + opset_imports=[oh.make_opsetid("", 18)], |
| 814 | + ) |
| 815 | + onnx.checker.check_model(model) |
| 816 | + self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32)) |
| 817 | + |
| 818 | + def test_op_trilu_1(self): |
| 819 | + model = oh.make_model( |
| 820 | + oh.make_graph( |
| 821 | + [oh.make_node("Trilu", ["X"], ["Z"], upper=0)], |
| 822 | + "dummy", |
| 823 | + [oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])], |
| 824 | + [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])], |
| 825 | + ), |
| 826 | + ir_version=9, |
| 827 | + opset_imports=[oh.make_opsetid("", 18)], |
| 828 | + ) |
| 829 | + onnx.checker.check_model(model) |
| 830 | + self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32)) |
| 831 | + |
| 832 | + @ignore_warnings(DeprecationWarning) |
| 833 | + def test_op_trilu_k(self): |
| 834 | + model = oh.make_model( |
| 835 | + oh.make_graph( |
| 836 | + [oh.make_node("Trilu", ["X", "k"], ["Z"], upper=1)], |
| 837 | + "dummy", |
| 838 | + [ |
| 839 | + oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]), |
| 840 | + oh.make_tensor_value_info("k", TINT64, []), |
| 841 | + ], |
| 842 | + [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])], |
| 843 | + ), |
| 844 | + ir_version=9, |
| 845 | + opset_imports=[oh.make_opsetid("", 18)], |
| 846 | + ) |
| 847 | + onnx.checker.check_model(model) |
| 848 | + self._finalize_test( |
| 849 | + model, |
| 850 | + torch.rand((6, 6), dtype=torch.float32), |
| 851 | + torch.tensor([2], dtype=torch.int64), |
| 852 | + ) |
| 853 | + |
804 | 854 |
|
805 | 855 | if __name__ == "__main__": |
806 | 856 | unittest.main(verbosity=2) |
0 commit comments