Skip to content

Commit 02185dd

Browse files
committed
tanh
1 parent 84f7fec commit 02185dd

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,21 @@ def test_op_softmax(self):
297297
model, torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)), atol=1e-6
298298
)
299299

300+
def test_op_tanh(self):
301+
model = oh.make_model(
302+
oh.make_graph(
303+
[oh.make_node("Tanh", ["X"], ["Z"])],
304+
"dummy",
305+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
306+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
307+
),
308+
ir_version=9,
309+
opset_imports=[oh.make_opsetid("", 18)],
310+
)
311+
self._finalize_test(
312+
model, torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)), atol=1e-6
313+
)
314+
300315

301316
if __name__ == "__main__":
302317
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from .access_ops import Gather_1, Slice_13
33
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
44
from .other_ops import Cast_6, Concat_1, Transpose_1
5-
from .nn_ops import Softmax_13
5+
from .nn_ops import Softmax_13, Tanh_6
66
from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_attribute_int(
9595
return default_value if att is None else int(att.i)
9696

9797
def get_attribute_ints(
98-
self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
98+
self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
9999
) -> Optional[Tuple[int, ...]]:
100100
"""
101101
Returns an attribute as an int.

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ def run(self, data: OpRunValue) -> OpRunValue:
2020
return OpRunValue(
2121
torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
2222
)
23+
24+
25+
class Tanh_6(OpRun):
26+
"Tanh"
27+
28+
def run(self, data: OpRunValue) -> OpRunValue:
29+
return OpRunValue(torch.nn.functional.tanh(data.tensor))

0 commit comments

Comments
 (0)