Skip to content

Commit 118388a

Browse files
committed
add erf, tile
1 parent b76f66e commit 118388a

File tree

4 files changed

+38
-1
lines changed

4 files changed

+38
-1
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,8 @@ def test_op_unary(self):
671671
oh.make_node("Cos", ["X"], ["nx"]),
672672
oh.make_node("Sin", ["nx"], ["t"]),
673673
oh.make_node("Exp", ["t"], ["u"]),
674-
oh.make_node("Log", ["u"], ["Z"]),
674+
oh.make_node("Log", ["u"], ["uZ"]),
675+
oh.make_node("Erf", ["uZ"], ["Z"]),
675676
],
676677
"dummy",
677678
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
@@ -1242,6 +1243,26 @@ def test_averagepool_2d(self):
12421243
),
12431244
)
12441245

1246+
def test_tile(self):
1247+
model = oh.make_model(
1248+
oh.make_graph(
1249+
[oh.make_node("Tile", ["x", "repeat"], ["y"])],
1250+
"ut",
1251+
[
1252+
oh.make_tensor_value_info("x", TFLOAT, [None, None]),
1253+
oh.make_tensor_value_info("repeat", TFLOAT, [None]),
1254+
],
1255+
[oh.make_tensor_value_info("y", TFLOAT, [None, None, None, None])],
1256+
),
1257+
opset_imports=[oh.make_opsetid("", 18)],
1258+
ir_version=10,
1259+
)
1260+
self._finalize_test(
1261+
model,
1262+
torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
1263+
torch.tensor([2, 2], dtype=torch.int64),
1264+
)
1265+
12451266

12461267
if __name__ == "__main__":
12471268
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
CastLike_15,
2424
NonZero_13,
2525
Concat_1,
26+
Tile_6,
2627
Transpose_1,
2728
Trilu_14,
2829
Where_9,
@@ -41,6 +42,7 @@
4142
from .unary_ops import (
4243
Abs_1,
4344
Cos_1,
45+
Erf_9,
4446
Exp_1,
4547
Identity_1,
4648
Log_1,

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ def run(self, x: OpRunTensor) -> OpRunTensor:
5252
return OpRunTensor(torch.nonzero(x.tensor).T)
5353

5454

55+
class Tile_6(OpRun):
56+
"Tile"
57+
58+
def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
59+
return OpRunTensor(torch.tile(x.tensor, repeat.as_tuple_int))
60+
61+
5562
class Transpose_1(OpRun):
5663
"Transpose"
5764

onnx_diagnostic/reference/torch_ops/unary_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ def run(self, x: OpRunTensor) -> OpRunTensor:
1616
return OpRunTensor(x.tensor.cos())
1717

1818

19+
class Erf_9(OpRun):
20+
"""Erf"""
21+
22+
def run(self, x: OpRunTensor) -> OpRunTensor:
23+
return OpRunTensor(x.tensor.erf())
24+
25+
1926
class Exp_1(OpRun):
2027
"""Exp"""
2128

0 commit comments

Comments
 (0)