Skip to content

Commit 21ee9fb

Browse files
committed
conv
1 parent 1e1656f commit 21ee9fb

File tree

4 files changed

+94
-2
lines changed

4 files changed

+94
-2
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,41 @@ def test_loop(self):
10691069
model, torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool)
10701070
)
10711071

1072+
def test_conv(self):
1073+
model = oh.make_model(
1074+
oh.make_graph(
1075+
[
1076+
oh.make_node(
1077+
"Conv",
1078+
["X", "W", "B"],
1079+
["Y"],
1080+
pads=[1, 1, 1, 1],
1081+
dilations=[1, 1],
1082+
strides=[2, 2],
1083+
)
1084+
],
1085+
"g",
1086+
[
1087+
oh.make_tensor_value_info("X", TFLOAT, [None, None, None, None]),
1088+
oh.make_tensor_value_info("W", TFLOAT, [None, None, None, None]),
1089+
oh.make_tensor_value_info("B", TFLOAT, [None, None, None, None]),
1090+
],
1091+
[oh.make_tensor_value_info("Y", TFLOAT, [None, None, None, None])],
1092+
),
1093+
opset_imports=[oh.make_opsetid("", 18)],
1094+
)
1095+
sH, sW = 5, 6
1096+
i = sH // 2
1097+
j = sW // 2
1098+
X = torch.zeros((1, 1, sH, sW), dtype=torch.float32)
1099+
X[0, 0, i, j] = 1.0
1100+
W = torch.zeros((1, 1, 3, 3), dtype=torch.float32)
1101+
W[0, 0, :, :] = torch.minimum(
1102+
2 ** torch.arange(9).reshape((3, -1)), torch.tensor([256])
1103+
)
1104+
B = torch.tensor([[[[0]]]], dtype=torch.float32)
1105+
self._finalize_test(model, X, W, B, use_ort=True)
1106+
10721107

10731108
if __name__ == "__main__":
10741109
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .controlflow_ops import If_1, Loop_16
1919
from .generator_ops import Range_11
20-
from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
20+
from .nn_ops import Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
2121
from .other_ops import Cast_6, CastLike_15, Concat_1, Transpose_1, Trilu_14, Where_9
2222
from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
2323
from .sequence_ops import ConcatFromSequence_11, SequenceEmpty_11, SequenceInsert_11

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,21 @@ def get_attribute_ints(
261261
:return: value
262262
"""
263263
att = self._find_attribute(node, name)
264-
return default_value if att is None else tuple(att.ints)
264+
return default_value if att is None else tuple(map(int, att.ints))
265+
266+
def get_attribute_string(
267+
self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
268+
) -> Optional[Tuple[int, ...]]:
269+
"""
270+
Returns an attribute as a tuple of ints.
271+
272+
:param node: NodeProto
273+
:param name: name
274+
:param default_value: default_value
275+
:return: value
276+
"""
277+
att = self._find_attribute(node, name)
278+
return default_value if att is None else att.s.decode("utf-8")
265279

266280
def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
267281
"""

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,49 @@
55
from . import OpRun, OpRunTensor
66

77

8+
class Conv_11(OpRun):
9+
"Conv"
10+
11+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
12+
super().__init__(node, version)
13+
self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
14+
self.dilations = self.get_attribute_ints(node, "dilations", None)
15+
self.group = self.get_attribute_int(node, "group", 1)
16+
self.kernel_shape = self.get_attribute_int(node, "kernel_shape", [])
17+
self.pads = self.get_attribute_ints(node, "pads", None)
18+
self.strides = self.get_attribute_ints(node, "strides", None)
19+
20+
def run(self, x, w, b=None):
21+
kernel_shape = self.kernel_shape or w.shape[2:]
22+
assert (
23+
tuple(kernel_shape) == w.shape[2:]
24+
), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
25+
dilations = self.dilations or [1 for _ in x.shape[2:]]
26+
strides = self.strides or [1 for _ in x.shape[2:]]
27+
pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
28+
assert (
29+
self.auto_pad == "NOTSET"
30+
), f"conv not implemented for auto_pad={self.auto_pad!r}"
31+
assert len(set(pads)) == 1, f"conv not implemented for dilations={pads}"
32+
if b is None:
33+
bias = None
34+
else:
35+
bias = b.tensor.squeeze()
36+
if not bias.shape:
37+
bias = bias.unsqueeze(0)
38+
return OpRunTensor(
39+
torch.nn.functional.conv2d(
40+
x.tensor,
41+
w.tensor,
42+
bias=bias,
43+
stride=tuple(strides),
44+
padding=pads[0],
45+
dilation=tuple(dilations),
46+
groups=self.group,
47+
)
48+
)
49+
50+
851
class LayerNormalization_17(OpRun):
952
"LayerNormalization"
1053

0 commit comments

Comments
 (0)