Commit f8afaf2

Commit message: reducemin
1 parent b9640ec, commit f8afaf2

3 files changed: +69 -0 lines changed


_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 27 additions & 0 deletions
@@ -345,6 +345,33 @@ def test_op_tanh(self):
             model, torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)), atol=1e-6
         )
 
+    def test_op_reduce_min(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("ReduceMin", ["X", "axes"], ["Z"])],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
+                    oh.make_tensor_value_info("axes", TINT64, ["f"]),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        self._finalize_test(
+            model,
+            torch.rand(3, 4, 5, dtype=torch.float32),
+            torch.tensor([1], dtype=torch.int64),
+            atol=1e-6,
+        )
+        self._finalize_test(
+            model,
+            torch.rand(3, 4, 5, dtype=torch.float32),
+            torch.tensor([1, 2], dtype=torch.int64),
+            atol=1e-6,
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
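
For context (not part of the commit): the values that the test helper `_finalize_test` checks can be reproduced independently with onnx's `ReferenceEvaluator`, and ONNX `ReduceMin` with the default `keepdims=1` matches `torch.amin` with `keepdim=True`. A minimal sketch, assuming only the standard `onnx` and `torch` APIs:

# Standalone sanity check: run the same model through onnx's reference
# implementation and compare against torch.amin over the same axes.
import numpy as np
import onnx.helper as oh
import torch
from onnx import TensorProto
from onnx.reference import ReferenceEvaluator

TFLOAT, TINT64 = TensorProto.FLOAT, TensorProto.INT64
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("ReduceMin", ["X", "axes"], ["Z"])],
        "dummy",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
            oh.make_tensor_value_info("axes", TINT64, ["f"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
    ),
    ir_version=9,
    opset_imports=[oh.make_opsetid("", 18)],
)
x = torch.rand(3, 4, 5, dtype=torch.float32)
(expected,) = ReferenceEvaluator(model).run(
    None, {"X": x.numpy(), "axes": np.array([1, 2], dtype=np.int64)}
)
assert expected.shape == (3, 1, 1)  # keepdims defaults to 1 in ONNX
torch.testing.assert_close(
    torch.from_numpy(expected), torch.amin(x, dim=(1, 2), keepdim=True)
)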

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
 )
 from .other_ops import Cast_6, Concat_1, Transpose_1
 from .nn_ops import Softmax_13, Tanh_6
+from .reduce_ops import ReduceMin_18
 from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13
 from .unary_ops import Neg_1, Not_1, Reciprocal_1

onnx_diagnostic/reference/torch_ops/reduce_ops.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+from typing import Optional
+import onnx
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRun, OpRunValue
+
+
+class ReduceOp(OpRun):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
+        self.noop_with_empty_axes = bool(
+            self.get_attribute_int(node, "noop_with_empty_axes", 0)
+        )
+        assert isinstance(
+            self.keepdims, bool
+        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
+        assert isinstance(self.noop_with_empty_axes, bool), (
+            f"Unexpected value for attribute "
+            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
+        )
+        assert (
+            not self.noop_with_empty_axes
+        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
+        # this is out of spec
+        stash_type = self.get_attribute_int(node, "stash_type", None)
+        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+
+class ReduceMin_18(ReduceOp):
+    """ReduceMin"""
+
+    def run(self, x: OpRunValue, axes: OpRunValue) -> OpRunValue:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        taxes = axes.as_tuple_int
+        if len(taxes) == 1:
+            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+            return OpRunValue(t.values)
+        t = x.tensor
+        for a in reversed(taxes):
+            t = t.min(a, keepdim=self.keepdims).values
+        return OpRunValue(t)
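
A note on the multi-axis branch: `torch.Tensor.min` reduces a single dimension per call, so `ReduceMin_18` iterates over the axes from highest to lowest, presumably so that, when `keepdims` is false, the remaining axis indices stay valid as dimensions are removed. A minimal sketch (not from the commit) checking that this loop agrees with `torch.amin`:

# Sketch: sequential per-axis min, highest axis first, matches torch.amin.
import torch

def sequential_min(t: torch.Tensor, axes: tuple, keepdim: bool) -> torch.Tensor:
    # Reduce the highest axis first so lower axis indices remain valid
    # when keepdim=False drops dimensions.
    for a in sorted(axes, reverse=True):
        t = t.min(a, keepdim=keepdim).values
    return t

x = torch.rand(3, 4, 5)
for keepdim in (True, False):
    torch.testing.assert_close(
        sequential_min(x, (1, 2), keepdim),
        torch.amin(x, dim=(1, 2), keepdim=keepdim),
    )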
