Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Change Logs
0.6.1
+++++

* :pr:`115`, :pr:`116`: first steps for TorchEvaluator
* :pr:`115`, :pr:`116`, :pr:`117`: first steps for TorchEvaluator
* :pr:`114`: extends the list of known rewritings
* :pr:`113`: fixes a couple of issues with ModelBuilder

Expand Down
6 changes: 6 additions & 0 deletions _doc/api/reference/torch_ops/access_ops.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

onnx_diagnostic.reference.torch_ops.access_ops
==============================================

.. automodule:: onnx_diagnostic.reference.torch_ops.access_ops
:members:
7 changes: 7 additions & 0 deletions _doc/api/reference/torch_ops/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ onnx_diagnostic.reference.torch_ops
:maxdepth: 1
:caption: modules

access_ops
binary_ops

OpRun
Expand All @@ -15,6 +16,12 @@ OpRun
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRun
:members:

OpRunValue
++++++++++

.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValue
:members:

Other functions
+++++++++++++++

Expand Down
26 changes: 26 additions & 0 deletions _unittests/ut_reference/test_torch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


TFLOAT = onnx.TensorProto.FLOAT
TINT64 = onnx.TensorProto.INT64


class TestTorchEvaluator(ExtTestCase):
Expand Down Expand Up @@ -72,6 +73,31 @@ def test_binary_ops(self):
else:
self.assertEmpty(v.value)

def test_slice_squeeze(self):
X = oh.make_tensor_value_info("X", TFLOAT, [None, None])
starts = oh.make_tensor_value_info("starts", TINT64, [None])
ends = oh.make_tensor_value_info("ends", TINT64, [None])
axes = oh.make_tensor_value_info("axes", TINT64, [None])
Y = oh.make_tensor_value_info("Y", TINT64, [None])
nodes = [
oh.make_node("Slice", ["X", "starts", "ends", "axes"], ["T"]),
oh.make_node("Squeeze", ["T", "axes"], ["Y"]),
]
graph = oh.make_graph(nodes, "g", [X, starts, ends, axes], [Y])
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
feeds = {
"X": torch.tensor([[0]], dtype=torch.int64),
"starts": torch.tensor([0], dtype=torch.int64),
"ends": torch.tensor([1], dtype=torch.int64),
"axes": torch.tensor([0], dtype=torch.int64),
}
expected = ExtendedReferenceEvaluator(model).run(
None, {k: v.numpy() for k, v in feeds.items()}
)
rt = TorchEvaluator(model)
got = rt.run(None, feeds)
self.assertEqualAny(expected, [g.detach().numpy() for g in got])


if __name__ == "__main__":
unittest.main(verbosity=2)
20 changes: 14 additions & 6 deletions onnx_diagnostic/reference/torch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
continue
opset = self.opsets[node.domain]
key = node.domain, node.op_type, opset
while key not in kernels:
while key not in kernels and opset > 0:
opset -= 1
key = node.domain, node.op_type, opset
assert (
Expand All @@ -135,7 +135,7 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):

def run(
self, outputs: Optional[List[str]], feeds: Dict[str, torch.Tensor]
) -> List[torch.Tensor]:
) -> List[Optional[torch.Tensor]]:
"""
Runs the ONNX model.

Expand All @@ -150,12 +150,20 @@ def run(
for k, v in self.constants.items():
r = self.runtime_info[k]
if not r.has_value:
r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)
r.set_value(
torch_ops.OpRunValue(
v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
)
)

# inputs
for k, v in feeds.items():
r = self.runtime_info[k]
r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)
r.set_value(
torch_ops.OpRunValue(
v.to(self.CUDA) if r.is_shape and self.on_cuda else v, False
)
)

# node execution
for it, kernel in enumerate(self.kernels):
Expand All @@ -174,12 +182,12 @@ def run(
self.runtime_info[name].clean_value()

# outputs
res = [self.runtime_info[o].value for o in outputs]
res = [self.runtime_info[o].value.tensor for o in outputs] # type: ignore[assignment, union-attr]

# clean previous execution
for k in feeds:
self.runtime_info[k].clean_value()
for o in outputs:
self.runtime_info[o].clean_value()

return res
return res # type: ignore[return-value]
5 changes: 3 additions & 2 deletions onnx_diagnostic/reference/torch_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._op_run import OpRun

from ._op_run import OpRun, OpRunValue
from .access_ops import Slice_13
from .binary_ops import Add_1, Div_1, Mul_1, Sub_1
from .shape_ops import Squeeze_13
41 changes: 39 additions & 2 deletions onnx_diagnostic/reference/torch_ops/_op_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,43 @@
from typing import Optional, Union, Tuple
import onnx
import torch


class OpRunValue:
"""
Wrapper around a tensor.

:param tensor: torch.Tensor
:param is_constant: is it a constant
"""

__slots__ = ("cached", "is_constant", "tensor")

def __init__(self, tensor, is_constant: bool = False):
self.tensor = tensor
self.is_constant = is_constant
self.cached = None

@property
def shape(self):
"shape"
return self.tensor.shape

@property
def dtype(self):
"dtype"
return self.tensor.dtype

def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
return tuple(map(int, self.tensor))

@property
def as_tuple_int(self):
"value as int"
if self.is_constant:
if self.cached is not None:
self.cached = self._tensor_as_tuple_int()
return self.cached
return self._tensor_as_tuple_int()


class OpRun:
Expand Down Expand Up @@ -31,7 +68,7 @@ def __str__(self) -> str:
)
return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"

def run(self, *args) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
def run(self, *args) -> Union[OpRunValue, Tuple[OpRunValue, ...]]:
"Kernel implementation."
raise NotImplementedError(
f"Method run is not implemented for kernel {self.__class__.__name__!r}"
Expand Down
32 changes: 32 additions & 0 deletions onnx_diagnostic/reference/torch_ops/access_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional
from . import OpRun, OpRunValue


class Slice_13(OpRun):
"Slice"

def run(
self,
data: OpRunValue,
starts: OpRunValue,
ends: OpRunValue,
axes: Optional[OpRunValue] = None,
steps: Optional[OpRunValue] = None,
) -> OpRunValue:
if axes is None:
if steps is None:
slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
else:
slices = [
slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor)
]
else:
if steps is None:
slices = [slice(0, a) for a in data.shape]
for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor):
slices[a] = slice(s, e)
else:
slices = [slice(0, a) for a in data.shape]
for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
slices[a] = slice(s, e, d)
return OpRunValue(data.tensor[tuple(slices)])
18 changes: 9 additions & 9 deletions onnx_diagnostic/reference/torch_ops/binary_ops.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from . import OpRun
from . import OpRun, OpRunValue


class Add_1(OpRun):
"""Add"""

def run(self, x, y):
return x + y
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
return OpRunValue(x.tensor + y.tensor)


class Mul_1(OpRun):
"""Mul"""

def run(self, x, y):
return x * y
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
return OpRunValue(x.tensor * y.tensor)


class Div_1(OpRun):
"""Div"""

def run(self, x, y):
return x / y
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
return OpRunValue(x.tensor / y.tensor)


class Sub_1(OpRun):
"""Sub"""

def run(self, x, y):
return x - y
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
return OpRunValue(x.tensor - y.tensor)
8 changes: 8 additions & 0 deletions onnx_diagnostic/reference/torch_ops/shape_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from . import OpRun, OpRunValue


class Squeeze_13(OpRun):
"Squeeze"

def run(self, data: OpRunValue, axes: OpRunValue) -> OpRunValue:
return OpRunValue(data.tensor.squeeze(axes.as_tuple_int))
2 changes: 1 addition & 1 deletion onnx_diagnostic/torch_onnx/runtime_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
name: str,
dtype: Optional[Any] = None,
shape: Optional[Tuple[Union[str, int], ...]] = None,
value: Optional[torch.Tensor] = None,
value: Optional[Any] = None,
first_used: Optional[int] = None,
last_used: Optional[int] = None,
created: Optional[int] = None,
Expand Down
Loading