Skip to content

Commit 3e2237b

Browse files
committed
op
1 parent 04517dd commit 3e2237b

File tree

5 files changed

+69
-10
lines changed

5 files changed

+69
-10
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_op_cast(self):
147147
ir_version=9,
148148
opset_imports=[oh.make_opsetid("", 18)],
149149
)
150-
self._finalize_test(model, (torch.rand((4, 5, 6, 7), dtype=torch.float32),))
150+
self._finalize_test(model, torch.rand((4, 5, 6, 7), dtype=torch.float32))
151151

152152
def test_op_transpose(self):
153153
model = oh.make_model(
@@ -180,6 +180,46 @@ def test_op_reshape(self):
180180
model, torch.rand((4, 5, 6, 7), dtype=torch.float32), torch.tensor([7, 4, 6, 5])
181181
)
182182

183+
def test_op_matmul(self):
184+
model = oh.make_model(
185+
oh.make_graph(
186+
[oh.make_node("MatMul", ["X", "Y"], ["Z"])],
187+
"dummy",
188+
[
189+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c", "d"]),
190+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "d", "f"]),
191+
],
192+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c", "f"])],
193+
),
194+
ir_version=9,
195+
opset_imports=[oh.make_opsetid("", 18)],
196+
)
197+
self._finalize_test(
198+
model,
199+
torch.rand((4, 5, 6, 7), dtype=torch.float32),
200+
torch.rand((4, 5, 7, 11), dtype=torch.float32),
201+
)
202+
203+
def test_op_unsqueeze(self):
204+
model = oh.make_model(
205+
oh.make_graph(
206+
[oh.make_node("Unsqueeze", ["X", "axes"], ["Z"])],
207+
"dummy",
208+
[
209+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", 1, "d"]),
210+
oh.make_tensor_value_info("axes", TINT64, ["s"]),
211+
],
212+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "d"])],
213+
),
214+
ir_version=9,
215+
opset_imports=[oh.make_opsetid("", 18)],
216+
)
217+
self._finalize_test(
218+
model,
219+
torch.rand((4, 5, 1, 7), dtype=torch.float32),
220+
torch.tensor([2], dtype=torch.int64),
221+
)
222+
183223

184224
if __name__ == "__main__":
185225
unittest.main(verbosity=2)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._op_run import OpRun, OpRunValue
22
from .access_ops import Slice_13
3-
from .binary_ops import Add_1, Div_1, Mul_1, Sub_1
3+
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
44
from .other_ops import Cast_6, Transpose_1
5-
from .shape_ops import Reshape_5, Shape_15, Squeeze_13
5+
from .shape_ops import Reshape_5, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/binary_ops.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,25 @@ def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
88
return OpRunValue(x.tensor + y.tensor)
99

1010

11-
class Mul_1(OpRun):
12-
"""Mul"""
11+
class Div_1(OpRun):
12+
"""Div"""
1313

1414
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
15-
return OpRunValue(x.tensor * y.tensor)
15+
return OpRunValue(x.tensor / y.tensor)
1616

1717

18-
class Div_1(OpRun):
19-
"""Div"""
18+
class MatMul_1(OpRun):
19+
"""MatMul"""
2020

2121
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
22-
return OpRunValue(x.tensor / y.tensor)
22+
return OpRunValue(x.tensor @ y.tensor)
23+
24+
25+
class Mul_1(OpRun):
26+
"""Mul"""
27+
28+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
29+
return OpRunValue(x.tensor * y.tensor)
2330

2431

2532
class Sub_1(OpRun):

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
class Cast_6(OpRun):
99
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
1010
super().__init__(node, version)
11-
self.to = onnx_dtype_to_torch_dtype(self.get_attribute_int(node, "to", 0))
11+
to = self.get_attribute_int(node, "to", 0)
12+
assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
13+
self.to = onnx_dtype_to_torch_dtype(to)
1214

1315
def run(self, data: OpRunValue) -> OpRunValue:
1416
return OpRunValue(data.tensor.to(self.to))

onnx_diagnostic/reference/torch_ops/shape_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,13 @@ class Squeeze_13(OpRun):
2828

2929
def run(self, data: OpRunValue, axes: OpRunValue) -> OpRunValue:
3030
return OpRunValue(data.tensor.squeeze(axes.as_tuple_int))
31+
32+
33+
class Unsqueeze_13(OpRun):
34+
"Unsqueeze"
35+
36+
def run(self, data: OpRunValue, axes: OpRunValue) -> OpRunValue:
37+
t = data.tensor
38+
for i in axes.as_tuple_int:
39+
t = t.unsqueeze(i)
40+
return OpRunValue(t)

0 commit comments

Comments
 (0)