Skip to content

Commit a1cbe24

Browse files
committed
add trilu
1 parent 2d43905 commit a1cbe24

File tree

6 files changed

+69
-5
lines changed

6 files changed

+69
-5
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import onnx.helper as oh
55
import onnx.numpy_helper as onh
66
import torch
7-
from onnx_diagnostic.ext_test_case import ExtTestCase
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
88
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
99
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
1010
from onnx_diagnostic.reference.torch_evaluator import get_kernels
@@ -801,6 +801,56 @@ def test_op_constant_of_shape(self):
801801
onnx.checker.check_model(model)
802802
self._finalize_test(model, torch.tensor([4, 5], dtype=torch.int64))
803803

804+
def test_op_trilu(self):
805+
model = oh.make_model(
806+
oh.make_graph(
807+
[oh.make_node("Trilu", ["X"], ["Z"])],
808+
"dummy",
809+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
810+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
811+
),
812+
ir_version=9,
813+
opset_imports=[oh.make_opsetid("", 18)],
814+
)
815+
onnx.checker.check_model(model)
816+
self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32))
817+
818+
def test_op_trilu_1(self):
819+
model = oh.make_model(
820+
oh.make_graph(
821+
[oh.make_node("Trilu", ["X"], ["Z"], upper=0)],
822+
"dummy",
823+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
824+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
825+
),
826+
ir_version=9,
827+
opset_imports=[oh.make_opsetid("", 18)],
828+
)
829+
onnx.checker.check_model(model)
830+
self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32))
831+
832+
@ignore_warnings(DeprecationWarning)
833+
def test_op_trilu_k(self):
834+
model = oh.make_model(
835+
oh.make_graph(
836+
[oh.make_node("Trilu", ["X", "k"], ["Z"], upper=1)],
837+
"dummy",
838+
[
839+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
840+
oh.make_tensor_value_info("k", TINT64, []),
841+
],
842+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
843+
),
844+
ir_version=9,
845+
opset_imports=[oh.make_opsetid("", 18)],
846+
)
847+
onnx.checker.check_model(model)
848+
self._finalize_test(
849+
model,
850+
torch.rand((6, 6), dtype=torch.float32),
851+
torch.tensor([2], dtype=torch.int64),
852+
)
853+
804854

805855
if __name__ == "__main__":
806856
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
152152
), f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}"
153153
cls = kernels[key]
154154
if cls.device_dependent():
155-
kernel = cls(node, opset, self.default_device)
155+
kernel = cls(node, opset, self.default_device) # type: ignore[call-arg]
156156
else:
157157
kernel = cls(node, opset)
158158
self.kernels.append(kernel)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from .generator_ops import Range_11
1818
from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
19-
from .other_ops import Cast_6, Concat_1, Transpose_1, Where_9
19+
from .other_ops import Cast_6, Concat_1, Transpose_1, Trilu_14, Where_9
2020
from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
2121
from .shape_ops import (
2222
ConstantOfShape_9,

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class OpRun:
6161
@classmethod
6262
def device_dependent(cls) -> bool:
6363
"""
64-
Returns True if the kernel needs a device to be efficiently initiliazed.
64+
Returns True if the kernel needs a device to be efficiently initialized.
6565
"""
6666
return False
6767

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ def run(self, data: OpRunValue) -> OpRunValue:
4242
return OpRunValue(torch.permute(data.tensor, self.perm))
4343

4444

45+
class Trilu_14(OpRun):
46+
"Trilu"
47+
48+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
49+
super().__init__(node, version)
50+
self.upper = self.get_attribute_int(node, "upper", 1)
51+
52+
def run(self, data: OpRunValue, k: Optional[OpRunValue] = None) -> OpRunValue:
53+
diagonal = 0 if k is None else k.tensor.item()
54+
if self.upper:
55+
return OpRunValue(torch.triu(data.tensor, diagonal=diagonal))
56+
return OpRunValue(torch.tril(data.tensor, diagonal=diagonal))
57+
58+
4559
class Where_9(OpRun):
4660
"Where"
4761

onnx_diagnostic/reference/torch_ops/shape_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class ConstantOfShape_9(OpRun):
1010
@classmethod
1111
def device_dependent(cls) -> bool:
1212
"""
13-
Returns True if the kernel needs a device to be efficiently initiliazed.
13+
Returns True if the kernel needs a device to be efficiently initialized.
1414
"""
1515
return True
1616

0 commit comments

Comments
 (0)