Skip to content

Commit fe286f1

Browse files
authored
First step for TorchEvaluator (#116)
* draft * First step for TorchEvaluator * changes * mypy * mypy * mypy
1 parent 7053dc2 commit fe286f1

File tree

15 files changed

+467
-8
lines changed

15 files changed

+467
-8
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ Change Logs
44
0.6.1
55
+++++
66

7+
* :pr:`115`, :pr:`116`: first steps for TorchEvaluator
78
* :pr:`114`: extends the list of known rewritings
89
* :pr:`113`: fixes a couple of issues with ModelBuilder
910

1011
0.6.0
1112
+++++
1213

13-
* :pr:`111`: support ModelBuilder with command line validatz
14+
* :pr:`111`: support ModelBuilder with command line validate
1415
* :pr:`108`, :pr:`109`, :pr:`110`: first version of an algorithm rendering
1516
small onnx graph in ascii, patch for ``torch.vmap``
1617

_doc/api/reference/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ onnx_diagnostic.reference
66
:caption: submodules
77

88
ops/index
9+
torch_ops/index
910

1011
.. toctree::
1112
:maxdepth: 1
@@ -14,6 +15,7 @@ onnx_diagnostic.reference
1415
evaluator
1516
quantized_tensor
1617
ort_evaluator
18+
torch_evaluator
1719

1820
ExtendedReferenceEvaluator
1921
++++++++++++++++++++++++++
@@ -27,6 +29,12 @@ OnnxruntimeEvaluator
2729
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
2830
:members:
2931

32+
TorchEvaluator
33+
++++++++++++++
34+
35+
.. autoclass:: onnx_diagnostic.reference.TorchEvaluator
36+
:members:
37+
3038
Other functions
3139
+++++++++++++++
3240

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.reference.torch_evaluator
3+
=========================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_evaluator
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: TorchEvaluator
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.binary_ops
3+
==============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.binary_ops
6+
:members:
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
2+
onnx_diagnostic.reference.torch_ops
3+
===================================
4+
5+
6+
.. toctree::
7+
:maxdepth: 1
8+
:caption: modules
9+
10+
binary_ops
11+
12+
OpRun
13+
+++++
14+
15+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRun
16+
:members:
17+
18+
Other functions
19+
+++++++++++++++
20+
21+
.. automodule:: onnx_diagnostic.reference.torch_ops
22+
:members:
23+
:no-undoc-members:

_doc/api/torch_onnx/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_onnx
55
:maxdepth: 1
66
:caption: submodules
77

8+
runtime_info
89
sbs
910

1011
.. automodule:: onnx_diagnostic.torch_onnx
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_onnx.runtime_info
3+
=======================================
4+
5+
.. automodule:: onnx_diagnostic.torch_onnx.runtime_info
6+
:members:
7+
:no-undoc-members:
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
import torch
7+
from onnx_diagnostic.ext_test_case import ExtTestCase
8+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchEvaluator
9+
from onnx_diagnostic.reference.torch_evaluator import get_kernels
10+
11+
12+
TFLOAT = onnx.TensorProto.FLOAT
13+
14+
15+
class TestTorchEvaluator(ExtTestCase):
    """Unit tests for :class:`TorchEvaluator` and the kernel registry."""

    def test_kernels(self):
        # The registry maps (domain, op_type, opset) to a kernel class.
        registry = get_kernels()
        self.assertIsInstance(registry, dict)
        add_key = ("", "Add", 1)
        self.assertIn(add_key, registry)
        self.assertEqual("Add_1", registry[add_key].__name__)

    def _binary_ops_model(self):
        # Builds a small graph chaining Add, Mul, Constant, Div and Sub.
        nodes = [
            oh.make_node("Add", ["X", "un"], ["xy"]),
            oh.make_node("Mul", ["xy", "Y"], ["xyy"]),
            oh.make_node(
                "Constant",
                [],
                ["deux"],
                value=onh.from_array(np.array([2], dtype=np.float32)),
            ),
            oh.make_node("Div", ["xyy", "deux"], ["xyyy"]),
            oh.make_node("Sub", ["xyyy", "Y"], ["Z"]),
        ]
        graph = oh.make_graph(
            nodes,
            "dummy",
            [
                oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
                oh.make_tensor_value_info("Y", TFLOAT, ["a", "b"]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
            [onh.from_array(np.array([1], dtype=np.float32), name="un")],
        )
        return oh.make_model(
            graph,
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )

    def test_binary_ops(self):
        model = self._binary_ops_model()
        onnx.checker.check_model(model)

        runtime = TorchEvaluator(model)
        self.assertEqual(5, len(runtime.kernels))
        self.assertEqual(2, len(runtime.constants))

        feeds = {
            "X": torch.rand((4, 5), dtype=torch.float32),
            "Y": torch.rand((4, 5), dtype=torch.float32),
        }

        # The reference evaluator gives the expected values.
        reference = ExtendedReferenceEvaluator(model)
        expected = reference.run(None, {k: v.numpy() for k, v in feeds.items()})
        got = runtime.run(None, feeds)
        self.assertEqualAny(expected, [g.detach().numpy() for g in got])

        n_nodes = len(model.graph.node)
        self.assertEqual(len(runtime.last_used), n_nodes)
        self.assertEqual(len(runtime.kernels), n_nodes)
        self.assertEqual([["X"], ["xy"], [], ["xyy"], ["Y", "xyyy"]], runtime.last_used)
        # Only constants keep a value after execution; every other
        # intermediate result must have been released.
        for name, info in runtime.runtime_info.items():
            if name in {"un", "deux"}:
                self.assertNotEmpty(info.value)
            else:
                self.assertEmpty(info.value)


if __name__ == "__main__":
    unittest.main(verbosity=2)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .evaluator import ExtendedReferenceEvaluator
22
from .ort_evaluator import OnnxruntimeEvaluator
3+
from .torch_evaluator import TorchEvaluator
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import functools
2+
from typing import Dict, List, Optional, Sequence, Tuple, Union
3+
import onnx
4+
import torch
5+
from ..helpers.torch_helper import to_tensor
6+
from ..torch_onnx.runtime_info import first_used_last_used
7+
from . import torch_ops
8+
9+
10+
@functools.lru_cache
def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRun]]:
    """Retrieves all the available kernels.

    Scans module ``torch_ops`` for classes deriving from
    :class:`OpRun <onnx_diagnostic.reference.torch_ops.OpRun>` whose name
    follows the ``<OpType>_<opset>`` convention and indexes them by
    ``(domain, op_type, opset)``.

    :return: mapping ``(domain, op_type, opset) -> kernel class``
    """
    res: Dict[Tuple[str, str, int], type[torch_ops.OpRun]] = {}
    for v in torch_ops.__dict__.values():
        if isinstance(v, type) and issubclass(v, torch_ops.OpRun) and "_" in v.__name__:
            # rsplit keeps operator types containing an underscore intact,
            # only the trailing ``_<version>`` suffix is split off
            # (``split("_")`` would raise on such names).
            name, version = v.__name__.rsplit("_", 1)
            domain = getattr(v, "domain", "")
            res[domain, name, int(version)] = v
    return res
20+
21+
22+
class TorchEvaluator:
    """
    Torch evaluator for onnx models.

    The instance does not store the original proto it evaluates, only
    what is required to execute it (kernels, constants, input and
    output names).

    :param proto: a proto
    :param providers: where to run the model
    :param opsets: needed if proto is a graph

    The class holds the following attributes:

    * `providers`: providers
    * `default_device`: default torch device
    * `constants`: all initializers or constants
    * `kernels`: kernels
    * `runtime_info`: produced by :func:`first_used_last_used
      <onnx_diagnostic.torch_onnx.runtime_info.first_used_last_used>`
    * `last_used`: contains the list of intermediate results,
      to remove after every node execution,
      this avoids the memory growing too much

    The class is not multithreaded. `runtime_info` gets updated
    by the class.
    """

    def __init__(
        self,
        proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
        providers: Tuple[str, ...] = ("CPUExecutionProvider",),
        opsets: Optional[Dict[str, int]] = None,
    ):
        self.providers = providers
        self.constants: Dict[str, torch.Tensor] = {}
        self.kernels: List[Optional[torch_ops.OpRun]] = []
        # Resolve the actual torch devices once; they are reused to move
        # constants and inputs around.
        self.CPU = torch.tensor([0]).to("cpu").device
        if "CUDAExecutionProvider" in providers:
            self.CUDA = torch.tensor([0]).to("cuda").device
            self.default_device = self.CUDA
        else:
            self.default_device = self.CPU

        if isinstance(proto, onnx.ModelProto):
            assert opsets is None, "proto is a model, opsets must be None in that case"
            assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
            self.opsets = {d.domain: d.version for d in proto.opset_import}
            self._build_initializers(proto.graph.initializer)
            self._build_initializers(proto.graph.node)
            self._build_kernels(proto.graph.node)
            self.input_names = [i.name for i in proto.graph.input]
            self.output_names = [i.name for i in proto.graph.output]
        elif isinstance(proto, onnx.GraphProto):
            assert opsets, "opsets must be specified if proto is a graph"
            assert not proto.sparse_initializer, "sparse_initializer not support yet"
            self.opsets = opsets
            # Fixed: the original passed ``proto`` itself (a GraphProto is not
            # a sequence of tensors) and ``proto.nodes`` (the repeated field
            # on GraphProto is named ``node``, ``nodes`` does not exist).
            self._build_initializers(proto.initializer)
            self._build_initializers(proto.node)
            self._build_kernels(proto.node)
            self.input_names = [i.name for i in proto.input]
            self.output_names = [i.name for i in proto.output]
        elif isinstance(proto, onnx.FunctionProto):
            assert opsets is None, "proto is a function, opsets must be None in that case"
            self.opsets = {d.domain: d.version for d in proto.opset_import}
            self._build_initializers(proto.node)
            self._build_kernels(proto.node)
            self.input_names = list(proto.input)
            self.output_names = list(proto.output)
        else:
            raise TypeError(f"Unexpected type {type(proto)} for proto")

        self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
        self.last_used: List[List[str]] = [[] for _ in self.kernels]
        for name, info in self.runtime_info.items():
            assert isinstance(info.last_used, int), f"Missing field last_used in {info!r}"
            if not info.is_output and not info.is_initializer:
                # A result can be freed right after the node using it last.
                self.last_used[info.last_used].append(name)

    @property
    def on_cuda(self) -> bool:
        """True if the default device is CUDA."""
        # Fixed: comparing to ``self.CUDA`` raised AttributeError on
        # CPU-only runs since that attribute is only set when the CUDA
        # provider is requested.
        return self.default_device.type == "cuda"

    def _build_initializers(self, inits: Sequence[Union[onnx.NodeProto, onnx.TensorProto]]):
        """Stores initializers and ``Constant`` node outputs into :attr:`constants`."""
        for init in inits:
            if isinstance(init, onnx.TensorProto):
                self.constants[init.name] = to_tensor(init).to(self.default_device)
            elif (
                isinstance(init, onnx.NodeProto)
                and init.op_type == "Constant"
                and init.domain == ""
            ):
                value = None
                for att in init.attribute:
                    if att.name == "value":
                        value = to_tensor(att.t).to(self.default_device)
                assert value is not None, f"No attribute value in node {init}"
                self.constants[init.output[0]] = value

    def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
        """Instantiates one kernel per node, ``None`` for ``Constant`` nodes."""
        kernels = get_kernels()
        self.kernels.clear()
        for node in nodes:
            if node.op_type == "Constant" and node.domain == "":
                # Treated as a constant.
                self.kernels.append(None)
                continue
            # Look for the most recent implementation of the operator
            # not newer than the opset declared by the model.
            opset = self.opsets[node.domain]
            key = node.domain, node.op_type, opset
            # Fixed: the search is now bounded, the original loop never
            # terminated when no kernel existed for the operator and the
            # assertion below was unreachable.
            while opset > 0 and key not in kernels:
                opset -= 1
                key = node.domain, node.op_type, opset
            assert (
                key in kernels
            ), f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}"
            self.kernels.append(kernels[key](node, opset))

    def run(
        self, outputs: Optional[List[str]], feeds: Dict[str, torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Runs the ONNX model.

        :param outputs: required outputs, defaults to the model outputs if None
        :param feeds: inputs
        :return: output tensors.
        """
        if outputs is None:
            outputs = self.output_names

        # sets constants
        for k, v in self.constants.items():
            r = self.runtime_info[k]
            if not r.has_value:
                r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)

        # inputs
        for k, v in feeds.items():
            r = self.runtime_info[k]
            r.set_value(v.to(self.CUDA) if r.is_shape and self.on_cuda else v)

        # node execution
        for it, kernel in enumerate(self.kernels):
            if kernel is not None:
                # kernel execution, an empty input name means an optional
                # input which was not provided
                inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
                res = kernel.run(*inputs)
                if isinstance(res, tuple):
                    for name, t in zip(kernel.output, res):
                        self.runtime_info[name].set_value(t)
                else:
                    self.runtime_info[kernel.output[0]].set_value(res)

            # free intermediate results
            for name in self.last_used[it]:
                self.runtime_info[name].clean_value()

        # outputs
        output_values = [self.runtime_info[o].value for o in outputs]

        # clean previous execution
        for k in feeds:
            self.runtime_info[k].clean_value()
        for o in outputs:
            self.runtime_info[o].clean_value()

        return output_values

0 commit comments

Comments
 (0)