
Commit 8a42d0e ("update"), parent aba66a0

4 files changed: +77, -3 lines

4 files changed

+77
-3
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 32 additions & 0 deletions
@@ -435,6 +435,23 @@ def test_op_reduce_min(self):
             atol=1e-6,
         )
 
+    def test_op_reduce_min_17(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("ReduceMin", ["X"], ["Z"], axes=[1])],
+                "dummy",
+                [oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
+                [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 17)],
+        )
+        self._finalize_test(
+            model,
+            torch.rand(3, 4, 5, dtype=torch.float32),
+            atol=1e-6,
+        )
+
     def test_op_reduce_sum(self):
         model = oh.make_model(
             oh.make_graph(

@@ -1104,6 +1121,21 @@ def test_conv(self):
         B = torch.tensor([[[[0]]]], dtype=torch.float32)
         self._finalize_test(model, X, W, B, use_ort=True)
 
+    def test_nonzero(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("NonZero", ["X"], ["Y"])],
+                "g",
+                [oh.make_tensor_value_info("X", TFLOAT, [None, None])],
+                [oh.make_tensor_value_info("Y", TINT64, [None, None])],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+
+        self._finalize_test(
+            model, torch.tensor([[1, 0], [1, 1]], dtype=torch.float32), use_ort=True
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
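
Note (illustration, not part of the commit): the new test targets opset 17, where ReduceMin still takes its axes as a node attribute; from opset 18 on, axes becomes an optional second input. The following standalone sketch builds both variants with onnx.helper and checks them with onnx's own ReferenceEvaluator; the aliases oh and TFLOAT mirror the test file's conventions.

# Minimal sketch: ReduceMin with axes as an attribute (opset 17) vs. axes as an
# input tensor (opset 18). Both reduce axis 1 with the default keepdims=1.
import numpy as np
import onnx.helper as oh
from onnx import TensorProto
from onnx.reference import ReferenceEvaluator

TFLOAT = TensorProto.FLOAT

model_17 = oh.make_model(
    oh.make_graph(
        [oh.make_node("ReduceMin", ["X"], ["Z"], axes=[1])],  # axes is an attribute
        "dummy",
        [oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
    ),
    opset_imports=[oh.make_opsetid("", 17)],
)

model_18 = oh.make_model(
    oh.make_graph(
        [oh.make_node("ReduceMin", ["X", "axes"], ["Z"])],  # axes is an input tensor
        "dummy",
        [oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
        [oh.make_tensor("axes", TensorProto.INT64, [1], [1])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

x = np.random.rand(3, 4, 5).astype(np.float32)
for m in (model_17, model_18):
    print(ReferenceEvaluator(m).run(None, {"X": x})[0].shape)  # (3, 1, 5) both times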

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -18,8 +18,16 @@
 from .controlflow_ops import If_1, Loop_16
 from .generator_ops import Range_11
 from .nn_ops import Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
-from .other_ops import Cast_6, CastLike_15, Concat_1, Transpose_1, Trilu_14, Where_9
-from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
+from .other_ops import (
+    Cast_6,
+    CastLike_15,
+    NonZero_13,
+    Concat_1,
+    Transpose_1,
+    Trilu_14,
+    Where_9,
+)
+from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_17, ReduceMin_18, ReduceSum_13
 from .sequence_ops import ConcatFromSequence_11, SequenceEmpty_11, SequenceInsert_11
 from .shape_ops import (
     ConstantOfShape_9,
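
Note: the exported class names carry the opset version as a suffix (ReduceMin_17, ReduceMin_18). The sketch below is a hypothetical illustration of how a suffix-based registry could dispatch to the newest kernel whose version does not exceed the model's opset; it is not taken from onnx_diagnostic, whose actual dispatch code may differ.

# Hypothetical dispatch sketch for version-suffixed kernel classes.
from typing import Dict, Tuple


def register(registry: Dict[Tuple[str, int], type], cls: type) -> None:
    # "ReduceMin_17" -> ("ReduceMin", 17)
    op_type, version = cls.__name__.rsplit("_", 1)
    registry[(op_type, int(version))] = cls


def lookup(registry: Dict[Tuple[str, int], type], op_type: str, opset: int) -> type:
    # pick the highest registered version that is still <= the model opset
    candidates = [v for (name, v) in registry if name == op_type and v <= opset]
    if not candidates:
        raise KeyError(f"no kernel registered for {op_type} with opset {opset}")
    return registry[(op_type, max(candidates))]


class ReduceMin_17: ...
class ReduceMin_18: ...

registry: Dict[Tuple[str, int], type] = {}
for cls in (ReduceMin_17, ReduceMin_18):
    register(registry, cls)

print(lookup(registry, "ReduceMin", 17))  # ReduceMin_17
print(lookup(registry, "ReduceMin", 20))  # ReduceMin_18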

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ def run(self, *data: OpRunTensor) -> OpRunTensor:
         return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
 
 
+class NonZero_13(OpRun):
+    "NonZero"
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.nonzero(x.tensor).T)
+
+
 class Transpose_1(OpRun):
     "Transpose"
 
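
Note (illustration, not part of the commit): the transpose in NonZero_13 bridges a layout difference. ONNX NonZero returns indices shaped (rank, n_nonzero), like numpy.nonzero stacked row-wise, while torch.nonzero returns (n_nonzero, rank).

# Quick check of the layout convention behind the .T in NonZero_13.
import numpy as np
import torch

x = torch.tensor([[1, 0], [1, 1]], dtype=torch.float32)
onnx_style = np.vstack(np.nonzero(x.numpy()))   # shape (2, 3): one row per dimension
torch_style = torch.nonzero(x)                  # shape (3, 2): one row per nonzero element
assert (torch_style.T.numpy() == onnx_style).all()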

onnx_diagnostic/reference/torch_ops/reduce_ops.py

Lines changed: 28 additions & 1 deletion
@@ -27,6 +27,12 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
         self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
 
 
+class ReduceOpAxes(ReduceOp):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.axes = self.get_attribute_ints(node, "axes", [])
+
+
 class ReduceMax_18(ReduceOp):
     """ReduceMax"""
 
@@ -65,6 +71,27 @@ def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor
         return OpRunTensor(t)
 
 
+class ReduceMin_17(ReduceOpAxes):
+    """ReduceMin"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        axes = self.axes
+        if not axes:
+            assert (
+                not self.keepdims
+            ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(torch.min(x.tensor))  # torch.min without dim returns a plain tensor
+        taxes = tuple(axes)
+        if len(taxes) == 1:
+            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t.values)
+        t = x.tensor
+        for a in reversed(taxes):
+            t = t.min(a, keepdim=self.keepdims).values
+        return OpRunTensor(t)
+
+
 class ReduceMin_18(ReduceOp):
     """ReduceMin"""
 
@@ -85,7 +112,7 @@ def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor
         return OpRunTensor(t)
 
 
-class ReduceSum_18(ReduceOp):
+class ReduceSum_13(ReduceOp):
     """ReduceSum"""
 
     def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
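
Note (standalone sanity check, not from the commit): ReduceMin_17 reduces one axis at a time, walking the axes in reverse order. With keepdim left on, this matches a single multi-axis reduction such as torch.amin, which accepts a tuple of dims.

# Verify that the per-axis loop used by ReduceMin_17 agrees with torch.amin.
import torch

x = torch.rand(3, 4, 5)
axes = (0, 2)

t = x
for a in reversed(axes):
    t = t.min(a, keepdim=True).values   # same iteration pattern as ReduceMin_17
loop_result = t                         # shape (1, 4, 1)

assert torch.equal(loop_result, torch.amin(x, dim=axes, keepdim=True))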
