Skip to content

Commit 8a78df2

Browse files
committed
reduce
1 parent f8afaf2 commit 8a78df2

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,60 @@ def test_op_tanh(self):
345345
model, torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)), atol=1e-6
346346
)
347347

348+
def test_op_reduce_max(self):
349+
model = oh.make_model(
350+
oh.make_graph(
351+
[oh.make_node("ReduceMax", ["X", "axes"], ["Z"])],
352+
"dummy",
353+
[
354+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
355+
oh.make_tensor_value_info("axes", TINT64, ["f"]),
356+
],
357+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
358+
),
359+
ir_version=9,
360+
opset_imports=[oh.make_opsetid("", 18)],
361+
)
362+
self._finalize_test(
363+
model,
364+
torch.rand(3, 4, 5, dtype=torch.float32),
365+
torch.tensor([1], dtype=torch.int64),
366+
atol=1e-6,
367+
)
368+
self._finalize_test(
369+
model,
370+
torch.rand(3, 4, 5, dtype=torch.float32),
371+
torch.tensor([1, 2], dtype=torch.int64),
372+
atol=1e-6,
373+
)
374+
375+
def test_op_reduce_mean(self):
376+
model = oh.make_model(
377+
oh.make_graph(
378+
[oh.make_node("ReduceMean", ["X", "axes"], ["Z"])],
379+
"dummy",
380+
[
381+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
382+
oh.make_tensor_value_info("axes", TINT64, ["f"]),
383+
],
384+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
385+
),
386+
ir_version=9,
387+
opset_imports=[oh.make_opsetid("", 18)],
388+
)
389+
self._finalize_test(
390+
model,
391+
torch.rand(3, 4, 5, dtype=torch.float32),
392+
torch.tensor([1], dtype=torch.int64),
393+
atol=1e-6,
394+
)
395+
self._finalize_test(
396+
model,
397+
torch.rand(3, 4, 5, dtype=torch.float32),
398+
torch.tensor([1, 2], dtype=torch.int64),
399+
atol=1e-6,
400+
)
401+
348402
def test_op_reduce_min(self):
349403
model = oh.make_model(
350404
oh.make_graph(
@@ -372,6 +426,33 @@ def test_op_reduce_min(self):
372426
atol=1e-6,
373427
)
374428

429+
def test_op_reduce_sum(self):
430+
model = oh.make_model(
431+
oh.make_graph(
432+
[oh.make_node("ReduceSum", ["X", "axes"], ["Z"])],
433+
"dummy",
434+
[
435+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
436+
oh.make_tensor_value_info("axes", TINT64, ["f"]),
437+
],
438+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
439+
),
440+
ir_version=9,
441+
opset_imports=[oh.make_opsetid("", 18)],
442+
)
443+
self._finalize_test(
444+
model,
445+
torch.rand(3, 4, 5, dtype=torch.float32),
446+
torch.tensor([1], dtype=torch.int64),
447+
atol=1e-6,
448+
)
449+
self._finalize_test(
450+
model,
451+
torch.rand(3, 4, 5, dtype=torch.float32),
452+
torch.tensor([1, 2], dtype=torch.int64),
453+
atol=1e-6,
454+
)
455+
375456

376457
if __name__ == "__main__":
377458
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
)
1616
from .other_ops import Cast_6, Concat_1, Transpose_1
1717
from .nn_ops import Softmax_13, Tanh_6
18-
from .reduce_ops import ReduceMin_18
18+
from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
1919
from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13
2020
from .unary_ops import Neg_1, Not_1, Reciprocal_1

onnx_diagnostic/reference/torch_ops/reduce_ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,34 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2626
self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
2727

2828

29+
class ReduceMax_18(ReduceOp):
30+
"""ReduceMax"""
31+
32+
def run(self, x: OpRunValue, axes: OpRunValue) -> OpRunValue:
33+
assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
34+
taxes = axes.as_tuple_int
35+
if len(taxes) == 1:
36+
t = x.tensor.max(taxes[0], keepdim=self.keepdims)
37+
return OpRunValue(t.values)
38+
t = x.tensor
39+
for a in reversed(taxes):
40+
t = t.max(a, keepdim=self.keepdims).values
41+
return OpRunValue(t)
42+
43+
44+
class ReduceMean_18(ReduceOp):
45+
"""ReduceMean"""
46+
47+
def run(self, x: OpRunValue, axes: OpRunValue) -> OpRunValue:
48+
assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
49+
taxes = axes.as_tuple_int
50+
if len(taxes) == 1:
51+
t = x.tensor.mean(taxes[0], keepdim=self.keepdims)
52+
return OpRunValue(t)
53+
t = x.tensor.mean(taxes, keepdim=self.keepdims)
54+
return OpRunValue(t)
55+
56+
2957
class ReduceMin_18(ReduceOp):
3058
"""ReduceMin"""
3159

@@ -39,3 +67,16 @@ def run(self, x: OpRunValue, axes: OpRunValue) -> OpRunValue:
3967
for a in reversed(taxes):
4068
t = t.min(a, keepdim=self.keepdims).values
4169
return OpRunValue(t)
70+
71+
72+
class ReduceSum_18(ReduceOp):
73+
"""ReduceSum"""
74+
75+
def run(self, x: OpRunValue, axes: OpRunValue) -> OpRunValue:
76+
assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
77+
taxes = axes.as_tuple_int
78+
if len(taxes) == 1:
79+
t = x.tensor.sum(taxes[0], keepdim=self.keepdims)
80+
return OpRunValue(t)
81+
t = x.tensor.sum(taxes, keepdim=self.keepdims)
82+
return OpRunValue(t)

0 commit comments

Comments
 (0)