Skip to content

Commit 21ec1de

Browse files
committed
add equal
1 parent 02185dd commit 21ec1de

File tree

4 files changed

+95
-2
lines changed

4 files changed

+95
-2
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,35 @@ def test_op_binary(self):
8484
else:
8585
self.assertEmpty(v.value)
8686

87+
def test_op_binary_cmp(self):
88+
model = oh.make_model(
89+
oh.make_graph(
90+
[
91+
oh.make_node("Greater", ["X", "Y"], ["a"]),
92+
oh.make_node("GreaterOrEqual", ["X", "Y"], ["b"]),
93+
oh.make_node("Less", ["X", "Y"], ["c"]),
94+
oh.make_node("LessOrEqual", ["X", "Y"], ["d"]),
95+
oh.make_node("And", ["a", "b"], ["ab"]),
96+
oh.make_node("Or", ["c", "d"], ["cd"]),
97+
oh.make_node("And", ["ab", "cd"], ["Z"]),
98+
],
99+
"dummy",
100+
[
101+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
102+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b"]),
103+
],
104+
[oh.make_tensor_value_info("Z", onnx.TensorProto.BOOL, ["a", "b"])],
105+
),
106+
ir_version=9,
107+
opset_imports=[oh.make_opsetid("", 18)],
108+
)
109+
onnx.checker.check_model(model)
110+
self._finalize_test(
111+
model,
112+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
113+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
114+
)
115+
87116
def test_op_slice_squeeze(self):
88117
X = oh.make_tensor_value_info("X", TFLOAT, [None, None])
89118
starts = oh.make_tensor_value_info("starts", TINT64, [None])
@@ -220,6 +249,7 @@ def test_op_matmul(self):
220249
model,
221250
torch.rand((4, 5, 6, 7), dtype=torch.float32),
222251
torch.rand((4, 5, 7, 11), dtype=torch.float32),
252+
atol=1e-6,
223253
)
224254

225255
def test_op_unsqueeze(self):

onnx_diagnostic/ext_test_case.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,11 +907,13 @@ def assertEqualArray(
907907
except AssertionError as e:
908908
expected_max = numpy.abs(expected).max()
909909
expected_value = numpy.abs(value).max()
910+
te = expected.astype(int) if expected.dtype == numpy.bool_ else expected
911+
tv = value.astype(int) if value.dtype == numpy.bool_ else value
910912
rows = [
911913
f"{msg}\n{e}" if msg else str(e),
912914
f"expected max value={expected_max}",
913915
f"expected computed value={expected_value}\n",
914-
f"ratio={expected / value}\ndiff={expected - value}",
916+
f"ratio={te / tv}\ndiff={te - tv}",
915917
]
916918
raise AssertionError("\n".join(rows)) # noqa: B904
917919

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
from ._op_run import OpRun, OpRunValue
22
from .access_ops import Gather_1, Slice_13
3-
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
3+
from .binary_ops import (
4+
And_1,
5+
Add_1,
6+
Div_1,
7+
Greater_1,
8+
GreaterOrEqual_1,
9+
Less_1,
10+
LessOrEqual_1,
11+
MatMul_1,
12+
Mul_1,
13+
Or_1,
14+
Sub_1,
15+
)
416
from .other_ops import Cast_6, Concat_1, Transpose_1
517
from .nn_ops import Softmax_13, Tanh_6
618
from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/binary_ops.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from . import OpRun, OpRunValue
22

33

4+
class And_1(OpRun):
5+
"""And"""
6+
7+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
8+
return OpRunValue(x.tensor & y.tensor)
9+
10+
411
class Add_1(OpRun):
512
"""Add"""
613

@@ -15,6 +22,41 @@ def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
1522
return OpRunValue(x.tensor / y.tensor)
1623

1724

25+
class Equal_1(OpRun):
26+
"""Equal"""
27+
28+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
29+
return OpRunValue(x.tensor == y.tensor)
30+
31+
32+
class Greater_1(OpRun):
33+
"""Greater"""
34+
35+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
36+
return OpRunValue(x.tensor > y.tensor)
37+
38+
39+
class GreaterOrEqual_1(OpRun):
40+
"""GreaterOrEqual"""
41+
42+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
43+
return OpRunValue(x.tensor >= y.tensor)
44+
45+
46+
class Less_1(OpRun):
47+
"""Less"""
48+
49+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
50+
return OpRunValue(x.tensor < y.tensor)
51+
52+
53+
class LessOrEqual_1(OpRun):
54+
"""LessOrEqual"""
55+
56+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
57+
return OpRunValue(x.tensor <= y.tensor)
58+
59+
1860
class MatMul_1(OpRun):
1961
"""MatMul"""
2062

@@ -29,6 +71,13 @@ def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
2971
return OpRunValue(x.tensor * y.tensor)
3072

3173

74+
class Or_1(OpRun):
75+
"""Or"""
76+
77+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
78+
return OpRunValue(x.tensor | y.tensor)
79+
80+
3281
class Sub_1(OpRun):
3382
"""Sub"""
3483

0 commit comments

Comments
 (0)