Skip to content

Commit c3e2626

Browse files
committed
softmax
1 parent df6f2c0 commit c3e2626

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from typing import Optional
23
import numpy as np
34
import onnx
45
import onnx.helper as oh
@@ -22,7 +23,7 @@ def test_kernels(self):
2223
kernel = ker[key]
2324
self.assertEqual("Add_1", kernel.__name__)
2425

25-
def _finalize_test(self, model, *args):
26+
def _finalize_test(self, model, *args, atol: Optional[float] = None):
2627
onnx.checker.check_model(model)
2728
feeds = dict(zip([i.name for i in model.graph.input], args))
2829

@@ -31,7 +32,7 @@ def _finalize_test(self, model, *args):
3132
)
3233
rt = TorchEvaluator(model)
3334
got = rt.run(None, feeds)
34-
self.assertEqualAny(expected, [g.detach().numpy() for g in got])
35+
self.assertEqualAny(expected, [g.detach().numpy() for g in got], atol=atol)
3536

3637
def test_op_binary(self):
3738
model = oh.make_model(
@@ -260,6 +261,21 @@ def test_op_gather(self):
260261
torch.tensor([0, 1, 3], dtype=torch.int64),
261262
)
262263

264+
def test_op_softmax(self):
265+
model = oh.make_model(
266+
oh.make_graph(
267+
[oh.make_node("Softmax", ["X"], ["Z"], axis=0)],
268+
"dummy",
269+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
270+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
271+
),
272+
ir_version=9,
273+
opset_imports=[oh.make_opsetid("", 18)],
274+
)
275+
self._finalize_test(
276+
model, torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)), atol=1e-6
277+
)
278+
263279

264280
if __name__ == "__main__":
265281
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from .access_ops import Gather_1, Slice_13
33
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
44
from .other_ops import Cast_6, Concat_1, Transpose_1
5+
from .nn_ops import Softmax_13
56
from .shape_ops import Reshape_5, Shape_15, Squeeze_13, Unsqueeze_13
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Optional
2+
import onnx
3+
import torch
4+
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5+
from . import OpRun, OpRunValue
6+
7+
8+
class Softmax_13(OpRun):
9+
"Softmax"
10+
11+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
12+
super().__init__(node, version)
13+
self.axis = self.get_attribute_int(node, "axis", -1)
14+
assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
15+
# this is out of spec
16+
stash_type = self.get_attribute_int(node, "stash_type", None)
17+
self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
18+
19+
def run(self, data: OpRunValue) -> OpRunValue:
20+
return OpRunValue(
21+
torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
22+
)

0 commit comments

Comments
 (0)