Skip to content

Commit b76f66e

Browse files
committed
add averagepool
1 parent a650468 commit b76f66e

File tree

3 files changed

+99
-3
lines changed

3 files changed

+99
-3
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,63 @@ def test_scatternd_2d(self):
11851185
use_ort=True,
11861186
)
11871187

1188+
def test_averagepool_1d(self):
1189+
model = oh.make_model(
1190+
oh.make_graph(
1191+
[
1192+
oh.make_node(
1193+
"AveragePool",
1194+
inputs=["x"],
1195+
outputs=["y"],
1196+
kernel_shape=[2],
1197+
)
1198+
],
1199+
"ut",
1200+
[oh.make_tensor_value_info("x", TFLOAT, [None, None, None])],
1201+
[oh.make_tensor_value_info("y", TFLOAT, [None, None, None])],
1202+
),
1203+
opset_imports=[oh.make_opsetid("", 18)],
1204+
ir_version=10,
1205+
)
1206+
self._finalize_test(model, torch.rand((1, 3, 32), dtype=torch.float32))
1207+
1208+
def test_averagepool_2d(self):
1209+
model = oh.make_model(
1210+
oh.make_graph(
1211+
[
1212+
oh.make_node(
1213+
"AveragePool",
1214+
inputs=["x"],
1215+
outputs=["y"],
1216+
kernel_shape=[5, 5],
1217+
pads=[2, 2, 2, 2],
1218+
)
1219+
],
1220+
"ut",
1221+
[oh.make_tensor_value_info("x", TFLOAT, [None, None, None, None])],
1222+
[oh.make_tensor_value_info("y", TFLOAT, [None, None, None, None])],
1223+
),
1224+
opset_imports=[oh.make_opsetid("", 18)],
1225+
ir_version=10,
1226+
)
1227+
self._finalize_test(
1228+
model,
1229+
torch.tensor(
1230+
[
1231+
[
1232+
[
1233+
[1, 2, 3, 4, 5],
1234+
[6, 7, 8, 9, 10],
1235+
[11, 12, 13, 14, 15],
1236+
[16, 17, 18, 19, 20],
1237+
[21, 22, 23, 24, 25],
1238+
]
1239+
]
1240+
],
1241+
dtype=torch.float32,
1242+
),
1243+
)
1244+
11881245

11891246
if __name__ == "__main__":
11901247
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .controlflow_ops import If_1, Loop_16
1919
from .generator_ops import Range_11
20-
from .nn_ops import Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
20+
from .nn_ops import AveragePool_11, Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
2121
from .other_ops import (
2222
Cast_6,
2323
CastLike_15,

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,45 @@
55
from . import OpRun, OpRunTensor
66

77

8+
class AveragePool_11(OpRun):
9+
"AveragePool"
10+
11+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
12+
super().__init__(node, version)
13+
self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
14+
self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
15+
self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
16+
self.dilations = self.get_attribute_ints(node, "dilations", None)
17+
self.kernel_shape: Tuple[int, ...] = (
18+
self.get_attribute_ints(node, "kernel_shape") or tuple()
19+
)
20+
self.pads = self.get_attribute_ints(node, "pads", None)
21+
self.strides = self.get_attribute_ints(node, "strides", None)
22+
23+
def run(self, x):
24+
kernel_shape = self.kernel_shape
25+
dilations = self.dilations or [1 for _ in x.shape[2:]]
26+
strides = self.strides or [1 for _ in x.shape[2:]]
27+
pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
28+
assert (
29+
self.auto_pad == "NOTSET"
30+
), f"conv not implemented for auto_pad={self.auto_pad!r}"
31+
assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
32+
assert set(dilations) == {1}, f"conv not implemented for dilations={dilations}"
33+
avg_pool = getattr(torch.nn.functional, f"avg_pool{len(kernel_shape)}d")
34+
return OpRunTensor(
35+
avg_pool(
36+
x.tensor,
37+
kernel_size=tuple(kernel_shape),
38+
stride=tuple(strides),
39+
padding=pads[0],
40+
ceil_mode=self.ceil_mode,
41+
count_include_pad=self.count_include_pad,
42+
# dilation=tuple(dilations),
43+
)
44+
)
45+
46+
847
class Conv_11(OpRun):
948
"Conv"
1049

@@ -22,15 +61,15 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2261
def run(self, x, w, b=None):
2362
kernel_shape = self.kernel_shape or w.shape[2:]
2463
assert (
25-
tuple(kernel_shape) == w.shape[2:]
64+
tuple(kernel_shape) == w.shape[-len(kernel_shape) :]
2665
), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
2766
dilations = self.dilations or [1 for _ in x.shape[2:]]
2867
strides = self.strides or [1 for _ in x.shape[2:]]
2968
pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
3069
assert (
3170
self.auto_pad == "NOTSET"
3271
), f"conv not implemented for auto_pad={self.auto_pad!r}"
33-
assert len(set(pads)) == 1, f"conv not implemented for dilations={pads}"
72+
assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
3473
if b is None:
3574
bias = None
3675
else:

0 commit comments

Comments
 (0)