Commit e5b2c4f

Add Shape op and other ops (#118)
* Add Shape op
* shape
* op
* add concat
* add gather
* softmax
* fix reshape
* tanh
* add equal
* unary
* reducemin
* reduce
* ind
* disable
* f
* fix layer norm
* mypy
1 parent 2750c6e commit e5b2c4f

File tree

13 files changed: +875 -11 lines


_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 463 additions & 2 deletions
Large diffs are not rendered by default.

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def test_validate_model_export(self):
         self.assertIsInstance(summary, dict)
         self.assertIsInstance(data, dict)
 
-    @requires_torch("2.7")
+    @requires_torch("2.8.99")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     def test_validate_model_onnx_dynamo_ir(self):
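
For context, requires_torch gates a test on the installed torch version. A minimal sketch of what such a decorator typically does; this simplified version is an assumption, the real implementation lives in onnx_diagnostic.ext_test_case:

import unittest
from packaging.version import Version


def requires_torch(min_version: str):
    # Sketch: skip the decorated test when the installed torch is too old.
    import torch

    installed = Version(torch.__version__.split("+")[0])
    if installed < Version(min_version):
        return unittest.skip(f"torch>={min_version} required, found {installed}")
    return lambda fn: fn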

onnx_diagnostic/ext_test_case.py

Lines changed: 3 additions & 1 deletion
@@ -907,11 +907,13 @@ def assertEqualArray(
         except AssertionError as e:
             expected_max = numpy.abs(expected).max()
             expected_value = numpy.abs(value).max()
+            te = expected.astype(int) if expected.dtype == numpy.bool_ else expected
+            tv = value.astype(int) if value.dtype == numpy.bool_ else value
             rows = [
                 f"{msg}\n{e}" if msg else str(e),
                 f"expected max value={expected_max}",
                 f"expected computed value={expected_value}\n",
-                f"ratio={expected / value}\ndiff={expected - value}",
+                f"ratio={te / tv}\ndiff={te - tv}",
             ]
             raise AssertionError("\n".join(rows))  # noqa: B904
 
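
The new bool-to-int cast matters because NumPy refuses to subtract boolean arrays, so the error report itself would raise on boolean outputs. A quick check:

import numpy

a = numpy.array([True, False])
b = numpy.array([True, True])
try:
    a - b
except TypeError as e:
    print(e)  # numpy boolean subtract, the `-` operator, is not supported
print(a.astype(int) - b.astype(int))  # [ 0 -1]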

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 8 additions & 1 deletion
@@ -172,16 +172,23 @@ def run(
             inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
             res = kernel.run(*inputs)
             if isinstance(res, tuple):
+                # outputs
+                assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
+                    f"Unexpected output type {[type(o) for o in res]} "
+                    f"for kernel {type(kernel)}."
+                )
                 for name, t in zip(kernel.output, res):
                     self.runtime_info[name].set_value(t)
             else:
+                assert isinstance(
+                    res, torch_ops.OpRunValue
+                ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                 self.runtime_info[kernel.output[0]].set_value(res)
 
             # free intermediate results
             for name in self.last_used[it]:
                 self.runtime_info[name].clean_value()
 
-        # outputs
         res = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[assignment, union-attr]
 
         # clean previous execution
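
The run loop now validates kernel outputs before storing them: a multi-output kernel returns a tuple of OpRunValue, a single-output kernel a bare OpRunValue. A condensed sketch of that dispatch, with simplified, assumed names:

def store_kernel_result(kernel, res, runtime_info):
    # Hypothetical condensed view of the logic above.
    if isinstance(res, tuple):
        # multi-output kernel: one OpRunValue per declared output name
        for name, value in zip(kernel.output, res):
            runtime_info[name].set_value(value)
    else:
        # single-output kernel: res is a single OpRunValue
        runtime_info[kernel.output[0]].set_value(res)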
onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 19 additions & 3 deletions
@@ -1,4 +1,20 @@
 from ._op_run import OpRun, OpRunValue
-from .access_ops import Slice_13
-from .binary_ops import Add_1, Div_1, Mul_1, Sub_1
-from .shape_ops import Squeeze_13
+from .access_ops import Gather_1, Slice_13
+from .binary_ops import (
+    And_1,
+    Add_1,
+    Div_1,
+    Greater_1,
+    GreaterOrEqual_1,
+    Less_1,
+    LessOrEqual_1,
+    MatMul_1,
+    Mul_1,
+    Or_1,
+    Sub_1,
+)
+from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
+from .other_ops import Cast_6, Concat_1, Transpose_1, Where_9
+from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
+from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13
+from .unary_ops import Neg_1, Not_1, Reciprocal_1
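
Each re-export follows the <OpType>_<opset> naming convention; note that Equal_1 is added in binary_ops.py below but is not re-exported here. A hypothetical sketch of how a dispatcher could resolve such names against a model's opset (the commit does not show the real lookup code):

import re
from onnx_diagnostic.reference import torch_ops


def find_kernel(op_type: str, opset: int):
    # Pick the highest kernel version not exceeding the model opset.
    best_version, best_cls = -1, None
    for name in dir(torch_ops):
        m = re.fullmatch(rf"{op_type}_(\d+)", name)
        if m and best_version < int(m.group(1)) <= opset:
            best_version, best_cls = int(m.group(1)), getattr(torch_ops, name)
    return best_cls


print(find_kernel("Softmax", 18))  # Softmax_13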

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 48 additions & 0 deletions
@@ -73,3 +73,51 @@ def run(self, *args) -> Union[OpRunValue, Tuple[OpRunValue, ...]]:
         raise NotImplementedError(
             f"Method run is not implemented for kernel {self.__class__.__name__!r}"
         )
+
+    def _find_attribute(self, node: onnx.NodeProto, name: str):
+        for att in node.attribute:
+            if att.name == name:
+                return att
+        return None
+
+    def get_attribute_float(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
+    ) -> Optional[float]:
+        """
+        Returns an attribute as a float.
+
+        :param node: NodeProto
+        :param name: attribute name
+        :param default_value: default value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else float(att.f)
+
+    def get_attribute_int(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
+    ) -> Optional[int]:
+        """
+        Returns an attribute as an int.
+
+        :param node: NodeProto
+        :param name: attribute name
+        :param default_value: default value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else int(att.i)
+
+    def get_attribute_ints(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
+    ) -> Optional[Tuple[int, ...]]:
+        """
+        Returns an attribute as a tuple of ints.
+
+        :param node: NodeProto
+        :param name: attribute name
+        :param default_value: default value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else tuple(att.ints)
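
These getters read typed fields straight off the protobuf attribute (att.f, att.i, att.ints). A standalone check of those fields using onnx.helper:

import onnx.helper

node = onnx.helper.make_node("Softmax", ["X"], ["Y"], axis=1)
att = node.attribute[0]
print(att.name, att.i)  # axis 1

node2 = onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[1, 0, 2])
print(tuple(node2.attribute[0].ints))  # (1, 0, 2)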

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 19 additions & 0 deletions
@@ -1,7 +1,26 @@
 from typing import Optional
+import onnx
+import torch
 from . import OpRun, OpRunValue
 
 
+class Gather_1(OpRun):
+    "Gather"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        axis = self.get_attribute_int(node, "axis", 0)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+
+    def run(self, x, indices):
+        if indices.tensor.numel() == 0:
+            return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device)
+        ind = [slice(0, s) for s in x.shape]
+        ind[self.axis] = indices.tensor
+        return OpRunValue(x.tensor[tuple(ind)])
+
+
 class Slice_13(OpRun):
     "Slice"
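
Gather_1 builds a full-slice index and substitutes the index tensor on the gather axis; note the empty-indices early return hands back a raw tensor rather than an OpRunValue. For 1-D indices the indexing trick is equivalent to torch.index_select:

import torch

x = torch.arange(12).reshape(3, 4)
indices = torch.tensor([2, 0])
ind = [slice(0, s) for s in x.shape]
ind[1] = indices  # gather along axis 1
print(x[tuple(ind)])                      # tensor([[ 2,  0], [ 6,  4], [10,  8]])
print(torch.index_select(x, 1, indices))  # same result for 1-D indices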

onnx_diagnostic/reference/torch_ops/binary_ops.py

Lines changed: 59 additions & 3 deletions
@@ -1,25 +1,81 @@
 from . import OpRun, OpRunValue
 
 
+class And_1(OpRun):
+    """And"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor & y.tensor)
+
+
 class Add_1(OpRun):
     """Add"""
 
     def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
         return OpRunValue(x.tensor + y.tensor)
 
 
+class Div_1(OpRun):
+    """Div"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor / y.tensor)
+
+
+class Equal_1(OpRun):
+    """Equal"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor == y.tensor)
+
+
+class Greater_1(OpRun):
+    """Greater"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor > y.tensor)
+
+
+class GreaterOrEqual_1(OpRun):
+    """GreaterOrEqual"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor >= y.tensor)
+
+
+class Less_1(OpRun):
+    """Less"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor < y.tensor)
+
+
+class LessOrEqual_1(OpRun):
+    """LessOrEqual"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor <= y.tensor)
+
+
+class MatMul_1(OpRun):
+    """MatMul"""
+
+    def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(x.tensor @ y.tensor)
+
+
 class Mul_1(OpRun):
     """Mul"""
 
     def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
         return OpRunValue(x.tensor * y.tensor)
 
 
-class Div_1(OpRun):
-    """Div"""
+class Or_1(OpRun):
+    """Or"""
 
     def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
-        return OpRunValue(x.tensor / y.tensor)
+        return OpRunValue(x.tensor | y.tensor)
 
 
 class Sub_1(OpRun):
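
The logical kernels map ONNX And/Or onto the & and | operators, which act elementwise on bool tensors, and MatMul_1 relies on @, which already implements ONNX's numpy-style batched matmul:

import torch

a = torch.tensor([True, True, False])
b = torch.tensor([True, False, False])
print(a & b)  # tensor([ True, False, False])
print(a | b)  # tensor([ True,  True, False])

x = torch.randn(2, 3, 4)
y = torch.randn(4, 5)
print((x @ y).shape)  # torch.Size([2, 3, 5])
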
onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+from typing import Optional
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRun, OpRunValue
+
+
+class LayerNormalization_17(OpRun):
+    "LayerNormalization"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+        self.stash_type = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+        )
+        self.compute_std = len(node.output) > 1
+
+    def run(self, x, scale, bias=None):
+        original_dtype = x.dtype
+        xt = x.tensor.to(self.stash_type)
+        res = torch.nn.functional.layer_norm(
+            xt,
+            xt.shape[self.axis :],
+            weight=scale.tensor,
+            bias=None if bias is None else bias.tensor,
+            eps=self.epsilon,
+        )
+        if not self.compute_std:
+            return OpRunValue(res.to(original_dtype))
+        axes = tuple(range(len(xt.shape)))[self.axis :]
+        mean, var = torch.var(xt, dim=axes, keepdim=False)
+        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
+        return OpRunValue(res.to(original_dtype)), OpRunValue(mean), OpRunValue(x_inv_std_dev)
+
+
+class Softmax_13(OpRun):
+    "Softmax"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
+        # this is out of spec
+        stash_type = self.get_attribute_int(node, "stash_type", None)
+        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+    def run(self, data: OpRunValue) -> OpRunValue:
+        return OpRunValue(
+            torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
+        )
+
+
+class Tanh_6(OpRun):
+    "Tanh"
+
+    def run(self, data: OpRunValue) -> OpRunValue:
+        return OpRunValue(torch.nn.functional.tanh(data.tensor))
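
One detail worth flagging in the two-output branch: torch.var returns a single tensor, while the unpacking mean, var = ... expects a pair; torch.var_mean is the call that returns one, in (var, mean) order. A minimal check of the main path against a manual reference, assuming the standard layer-norm formula:

import torch

x = torch.randn(2, 5, 8)
w, b = torch.ones(8), torch.zeros(8)
res = torch.nn.functional.layer_norm(x, x.shape[-1:], weight=w, bias=b, eps=1e-5)

# manual reference: layer_norm uses the biased variance estimator
var, mean = torch.var_mean(x, dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-5)
print(torch.allclose(res, manual, atol=1e-5))  # True
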
onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from typing import Optional
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRun, OpRunValue
+
+
+class Cast_6(OpRun):
+    "Cast"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        to = self.get_attribute_int(node, "to", 0)
+        assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
+        self.to = onnx_dtype_to_torch_dtype(to)
+
+    def run(self, data: OpRunValue) -> OpRunValue:
+        return OpRunValue(data.tensor.to(self.to))
+
+
+class Concat_1(OpRun):
+    "Concat"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        axis = self.get_attribute_int(node, "axis", 0)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+
+    def run(self, *data: OpRunValue) -> OpRunValue:
+        return OpRunValue(torch.cat([t.tensor for t in data], axis=self.axis))
+
+
+class Transpose_1(OpRun):
+    "Transpose"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.perm = self.get_attribute_ints(node, "perm", None)
+
+    def run(self, data: OpRunValue) -> OpRunValue:
+        return OpRunValue(torch.permute(data.tensor, self.perm))
+
+
+class Where_9(OpRun):
+    "Where"
+
+    def run(self, cond: OpRunValue, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(torch.where(cond.tensor, x.tensor, y.tensor))
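
Cast_6 converts the ONNX TensorProto dtype enum to a torch dtype through the project's onnx_dtype_to_torch_dtype helper before calling Tensor.to. The idea in plain torch, with an illustrative subset of the mapping (an assumption, not the helper's actual table):

import onnx
import torch

# illustrative subset of an ONNX-to-torch dtype table, not the helper's real contents
ONNX_TO_TORCH = {
    onnx.TensorProto.FLOAT: torch.float32,
    onnx.TensorProto.FLOAT16: torch.float16,
    onnx.TensorProto.INT64: torch.int64,
    onnx.TensorProto.BOOL: torch.bool,
}

t = torch.tensor([1, 2, 3])
print(t.to(ONNX_TO_TORCH[onnx.TensorProto.FLOAT]))  # tensor([1., 2., 3.])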
