Skip to content

Commit 4f620a2

Browse files
committed
Add OpRunValue
1 parent fe286f1 commit 4f620a2

File tree

9 files changed

+141
-15
lines changed

9 files changed

+141
-15
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.access_ops
3+
==============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.access_ops
6+
:members:

_doc/api/reference/torch_ops/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ onnx_diagnostic.reference.torch_ops
77
:maxdepth: 1
88
:caption: modules
99

10+
access_ops
1011
binary_ops
1112

1213
OpRun
@@ -15,6 +16,12 @@ OpRun
1516
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRun
1617
:members:
1718

19+
OpRunValue
20+
++++++++++
21+
22+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValue
23+
:members:
24+
1825
Other functions
1926
+++++++++++++++
2027

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
TFLOAT = onnx.TensorProto.FLOAT
13+
TINT64 = onnx.TensorProto.INT64
1314

1415

1516
class TestTorchEvaluator(ExtTestCase):
@@ -72,6 +73,31 @@ def test_binary_ops(self):
7273
else:
7374
self.assertEmpty(v.value)
7475

76+
def test_slice_squeeze(self):
77+
X = oh.make_tensor_value_info("X", TFLOAT, [None, None])
78+
starts = oh.make_tensor_value_info("starts", TINT64, [None])
79+
ends = oh.make_tensor_value_info("ends", TINT64, [None])
80+
axes = oh.make_tensor_value_info("axes", TINT64, [None])
81+
Y = oh.make_tensor_value_info("Y", TINT64, [None])
82+
nodes = [
83+
oh.make_node("Slice", ["X", "starts", "ends", "axes"], ["T"]),
84+
oh.make_node("Squeeze", ["T", "axes"], ["Y"]),
85+
]
86+
graph = oh.make_graph(nodes, "g", [X, starts, ends, axes], [Y])
87+
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
88+
feeds = {
89+
"X": torch.tensor([[0]], dtype=torch.int64),
90+
"starts": torch.tensor([0], dtype=torch.int64),
91+
"ends": torch.tensor([1], dtype=torch.int64),
92+
"axes": torch.tensor([0], dtype=torch.int64),
93+
}
94+
expected = ExtendedReferenceEvaluator(model).run(
95+
None, {k: v.numpy() for k, v in feeds.items()}
96+
)
97+
rt = TorchEvaluator(model)
98+
got = rt.run(None, feeds)
99+
self.assertEqualAny(expected, [g.detach().numpy() for g in got])
100+
75101

76102
if __name__ == "__main__":
77103
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
125125
continue
126126
opset = self.opsets[node.domain]
127127
key = node.domain, node.op_type, opset
128-
while key not in kernels:
128+
while key not in kernels and opset > 0:
129129
opset -= 1
130130
key = node.domain, node.op_type, opset
131131
assert (
@@ -150,12 +150,20 @@ def run(
150150
for k, v in self.constants.items():
151151
r = self.runtime_info[k]
152152
if not r.has_value:
153-
r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)
153+
r.set_value(
154+
torch_ops.OpRunValue(
155+
v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
156+
)
157+
)
154158

155159
# inputs
156160
for k, v in feeds.items():
157161
r = self.runtime_info[k]
158-
r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)
162+
r.set_value(
163+
torch_ops.OpRunValue(
164+
v.to(self.CUDA) if r.is_shape and self.on_cuda else v, False
165+
)
166+
)
159167

160168
# node execution
161169
for it, kernel in enumerate(self.kernels):
@@ -174,7 +182,7 @@ def run(
174182
self.runtime_info[name].clean_value()
175183

176184
# outputs
177-
res = [self.runtime_info[o].value for o in outputs]
185+
res = [self.runtime_info[o].value.tensor for o in outputs]
178186

179187
# clean previous execution
180188
for k in feeds:
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from ._op_run import OpRun
2-
1+
from ._op_run import OpRun, OpRunValue
2+
from .access_ops import Slice_13
33
from .binary_ops import Add_1, Div_1, Mul_1, Sub_1
4+
from .shape_ops import Squeeze_13

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,44 @@
33
import torch
44

55

6+
class OpRunValue:
7+
"""
8+
Wrapper around the tensor.
9+
10+
:param tensor: torch.Tensor
11+
:param is_constant: is it a constant
12+
"""
13+
14+
__slots__ = ("cached", "is_constant", "tensor")
15+
16+
def __init__(self, tensor: torch.Tensor, is_constant: bool = False):
17+
self.tensor = tensor
18+
self.is_constant = is_constant
19+
self.cached = None
20+
21+
@property
22+
def shape(self):
23+
"shape"
24+
return self.tensor.shape
25+
26+
@property
27+
def dtype(self):
28+
"dtype"
29+
return self.tensor.dtype
30+
31+
def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
32+
return tuple(map(int, self.tensor))
33+
34+
@property
35+
def as_tuple_int(self):
36+
"value as int"
37+
if self.is_constant:
38+
if self.cached is not None:
39+
self.cached = self._tensor_as_tuple_int()
40+
return self.cached
41+
return self._tensor_as_tuple_int()
42+
43+
644
class OpRun:
745
"""
846
Main class. Every kernel should inherit from it.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Optional
2+
from . import OpRun, OpRunValue
3+
4+
5+
class Slice_13(OpRun):
6+
"Slice"
7+
8+
def run(
9+
self,
10+
data: OpRunValue,
11+
starts: OpRunValue,
12+
ends: OpRunValue,
13+
axes: Optional[OpRunValue] = None,
14+
steps: Optional[OpRunValue] = None,
15+
) -> OpRunValue:
16+
if axes is None:
17+
if steps is None:
18+
slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
19+
else:
20+
slices = [
21+
slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor)
22+
]
23+
else:
24+
if steps is None:
25+
slices = [slice(0, a) for a in data.shape]
26+
for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor):
27+
slices[a] = slice(s, e)
28+
else:
29+
slices = [slice(0, a) for a in data.shape]
30+
for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
31+
slices[a] = slice(s, e, d)
32+
return OpRunValue(data.tensor[tuple(slices)])
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
from . import OpRun
1+
from . import OpRun, OpRunValue
22

33

44
class Add_1(OpRun):
55
"""Add"""
66

7-
def run(self, x, y):
8-
return x + y
7+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
8+
return OpRunValue(x.tensor + y.tensor)
99

1010

1111
class Mul_1(OpRun):
1212
"""Mul"""
1313

14-
def run(self, x, y):
15-
return x * y
14+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
15+
return OpRunValue(x.tensor * y.tensor)
1616

1717

1818
class Div_1(OpRun):
1919
"""Div"""
2020

21-
def run(self, x, y):
22-
return x / y
21+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
22+
return OpRunValue(x.tensor / y.tensor)
2323

2424

2525
class Sub_1(OpRun):
2626
"""Sub"""
2727

28-
def run(self, x, y):
29-
return x - y
28+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
29+
return OpRunValue(x.tensor - y.tensor)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from . import OpRun, OpRunValue
2+
3+
4+
class Squeeze_13(OpRun):
5+
"Squeeze"
6+
7+
def run(self, data: OpRunValue, axes: OpRunValue) -> OpRunValue:
8+
return OpRunValue(data.tensor.squeeze(axes.as_tuple_int))

0 commit comments

Comments
 (0)