Skip to content

Commit 84f7fec

Browse files
committed
fix reshape
1 parent c3e2626 commit 84f7fec

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
from typing import Optional
32
import numpy as np
43
import onnx
54
import onnx.helper as oh
@@ -23,7 +22,7 @@ def test_kernels(self):
2322
kernel = ker[key]
2423
self.assertEqual("Add_1", kernel.__name__)
2524

26-
def _finalize_test(self, model, *args, atol: Optional[float] = None):
25+
def _finalize_test(self, model, *args, atol: float = 0):
2726
onnx.checker.check_model(model)
2827
feeds = dict(zip([i.name for i in model.graph.input], args))
2928

@@ -178,7 +177,29 @@ def test_op_reshape(self):
178177
opset_imports=[oh.make_opsetid("", 18)],
179178
)
180179
self._finalize_test(
181-
model, torch.rand((4, 5, 6, 7), dtype=torch.float32), torch.tensor([7, 4, 6, 5])
180+
model,
181+
torch.rand((4, 5, 6, 7), dtype=torch.float32),
182+
torch.tensor([7, 4, 6, 5], dtype=torch.int64),
183+
)
184+
185+
def test_op_reshape_zero(self):
186+
model = oh.make_model(
187+
oh.make_graph(
188+
[oh.make_node("Reshape", ["X", "shape"], ["Y"])],
189+
"dummy",
190+
[
191+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c", "d"]),
192+
oh.make_tensor_value_info("shape", TINT64, ["f"]),
193+
],
194+
[oh.make_tensor_value_info("Y", TFLOAT, ["d", "a", "c", "b"])],
195+
),
196+
ir_version=9,
197+
opset_imports=[oh.make_opsetid("", 18)],
198+
)
199+
self._finalize_test(
200+
model,
201+
torch.rand((4, 5, 6, 7), dtype=torch.float32),
202+
torch.tensor([7, 4, 0, 5], dtype=torch.int64),
182203
)
183204

184205
def test_op_matmul(self):

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
44
from .other_ops import Cast_6, Concat_1, Transpose_1
55
from .nn_ops import Softmax_13
6-
from .shape_ops import Reshape_5, Shape_15, Squeeze_13, Unsqueeze_13
6+
from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/shape_ops.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,21 @@
44
from . import OpRun, OpRunValue
55

66

7-
class Reshape_5(OpRun):
7+
class Reshape_14(OpRun):
88
"Reshape"
99

10+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
11+
super().__init__(node, version)
12+
self.allowzero = self.get_attribute_int(node, "allowzero", 0)
13+
1014
def run(self, data: OpRunValue, shape: OpRunValue) -> OpRunValue:
15+
ishape = shape.as_tuple_int
16+
if not self.allowzero and 0 in ishape:
17+
xshape = data.tensor.shape
18+
new_shape = []
19+
for i, s in enumerate(ishape):
20+
new_shape.append(xshape[i] if s == 0 else s)
21+
return OpRunValue(data.tensor.reshape(new_shape))
1122
return OpRunValue(data.tensor.reshape(shape.as_tuple_int))
1223

1324

0 commit comments

Comments
 (0)