Skip to content

Commit b8b48bc

Browse files
authored
Add local functions to TorchOnnxEvaluator (#122)
* Add local functions * mypy * mypy * mypy * mypy
1 parent bd3b352 commit b8b48bc

File tree

9 files changed

+220
-23
lines changed

9 files changed

+220
-23
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.6.1
55
+++++
66

7+
* :pr:`122`: add local functions to TorchOnnxEvaluator
78
* :pr:`120`: enables TorchOnnxEvaluator in command line ``python -m onnx_diagnostic validate ...``
89
* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`:
910
first steps for TorchOnnxEvaluator

_doc/api/reference/torch_ops/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ OpRunValue
2222
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValue
2323
:members:
2424

25+
OpRunFunction
26+
+++++++++++++
27+
28+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunFunction
29+
:members:
30+
2531
Other functions
2632
+++++++++++++++
2733

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ def test_op_binary_cmp(self):
119119
ir_version=9,
120120
opset_imports=[oh.make_opsetid("", 18)],
121121
)
122-
onnx.checker.check_model(model)
123122
self._finalize_test(
124123
model,
125124
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
@@ -853,6 +852,51 @@ def test_op_trilu_k(self):
853852
torch.tensor([2], dtype=torch.int64),
854853
)
855854

855+
def test_local_function(self):
856+
new_domain = "custom"
857+
858+
linear_regression = oh.make_function(
859+
new_domain,
860+
"LinearRegression",
861+
["x", "a", "b"],
862+
["y"],
863+
[
864+
oh.make_node("MatMul", ["x", "a"], ["xa"]),
865+
oh.make_node("Add", ["xa", "b"], ["y"]),
866+
],
867+
[oh.make_opsetid("", 18)],
868+
[],
869+
)
870+
871+
graph = oh.make_graph(
872+
[
873+
oh.make_node("LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain),
874+
oh.make_node("Abs", ["Y1"], ["Y"]),
875+
],
876+
"example",
877+
[
878+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
879+
oh.make_tensor_value_info("A", TFLOAT, ["a", "b"]),
880+
oh.make_tensor_value_info("B", TFLOAT, ["a", "b"]),
881+
],
882+
[oh.make_tensor_value_info("Y", TFLOAT, ["a", "b"])],
883+
)
884+
885+
model = oh.make_model(
886+
graph,
887+
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid(new_domain, 1)],
888+
functions=[linear_regression],
889+
ir_version=10,
890+
)
891+
self.assertNotEmpty(model.functions)
892+
self._finalize_test(
893+
model,
894+
torch.rand((3, 3), dtype=torch.float32),
895+
torch.rand((3, 3), dtype=torch.float32),
896+
torch.rand((3, 3), dtype=torch.float32),
897+
atol=1e-6,
898+
)
899+
856900

857901
if __name__ == "__main__":
858902
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TorchOnnxEvaluator:
4343
:param proto: a proto
4444
:param providers: where to run the model
4545
:param opsets: needed if proto is a graph
46+
:param functions: known local functions
4647
4748
The class holds the following attributes:
4849
@@ -66,10 +67,12 @@ def __init__(
6667
proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
6768
providers: Tuple[str, ...] = ("CPUExecutionProvider",),
6869
opsets: Optional[Dict[str, int]] = None,
70+
local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
6971
):
7072
self.providers = providers
7173
self.constants: Dict[str, torch.Tensor] = {}
7274
self.kernels: List[Optional[torch_ops.OpRun]] = []
75+
self.functions = local_functions.copy() if local_functions else {}
7376
self.CPU = torch.tensor([0]).to("cpu").device
7477
if "CUDAExecutionProvider" in providers:
7578
self.CUDA = torch.tensor([0]).to("cuda").device
@@ -83,6 +86,10 @@ def __init__(
8386
assert opsets is None, "proto is a model, opsets must be None in that case"
8487
assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
8588
self.opsets = {d.domain: d.version for d in proto.opset_import}
89+
for f in proto.functions:
90+
self.functions[f.domain, f.name] = TorchOnnxEvaluator(
91+
f, providers=providers, local_functions=self.functions
92+
)
8693
self._build_initializers(proto.graph.initializer)
8794
self._build_initializers(proto.graph.node)
8895
self._build_kernels(proto.graph.node)
@@ -138,24 +145,33 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
138145
kernels = get_kernels()
139146
self.kernels.clear()
140147
for node in nodes:
148+
if (node.domain, node.op_type) in self.functions:
149+
kernel = torch_ops.OpRunFunction(
150+
self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
151+
)
152+
self.kernels.append(kernel)
153+
continue
154+
141155
if node.op_type == "Constant" and node.domain == "":
142156
# Treated as a constant.
143157
self.kernels.append(None)
144158
continue
159+
145160
opset = self.opsets[node.domain]
146161
key = node.domain, node.op_type, opset
147162
while key not in kernels and opset > 0:
148163
opset -= 1
149164
key = node.domain, node.op_type, opset
150-
assert (
151-
key in kernels
152-
), f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}"
165+
assert key in kernels, (
166+
f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
167+
f"local functions={sorted(self.functions)}"
168+
)
153169
cls = kernels[key]
154170
if cls.device_dependent():
155-
kernel = cls(node, opset, self.default_device) # type: ignore[call-arg]
171+
kernel2: torch_ops.OpRun = cls(node, opset, self.default_device) # type: ignore[call-arg]
156172
else:
157-
kernel = cls(node, opset)
158-
self.kernels.append(kernel)
173+
kernel2 = cls(node, opset) # type: ignore[assignment]
174+
self.kernels.append(kernel2)
159175

160176
def run(
161177
self,
@@ -165,7 +181,7 @@ def run(
165181
"""
166182
Runs the ONNX model.
167183
168-
:param outputs: outputs required:
184+
:param outputs: outputs required
169185
:param feeds: inputs
170186
:return: output tensors.
171187
"""
@@ -218,7 +234,10 @@ def run(
218234
for name in self.last_used[it]:
219235
self.runtime_info[name].clean_value()
220236

221-
res = [self.runtime_info[o].value.tensor for o in outputs] # type: ignore[assignment, union-attr]
237+
assert all(
238+
self.runtime_info[o].value is not None for o in outputs
239+
), "Not implemented yet when one output is None."
240+
fres = [self.runtime_info[o].value.tensor for o in outputs] # type: ignore[union-attr]
222241

223242
# clean previous execution
224243
for k in feeds:
@@ -227,5 +246,71 @@ def run(
227246
self.runtime_info[o].clean_value()
228247

229248
if use_numpy:
230-
return [None if a is None else a.detach().cpu().numpy() for a in res] # type: ignore[union-attr]
231-
return res # type: ignore[return-value]
249+
return [None if a is None else a.detach().cpu().numpy() for a in fres]
250+
return fres
251+
252+
def run_with_values(
253+
self, *args: Optional[torch_ops.OpRunValue]
254+
) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
255+
"""
256+
Runs the ONNX model.
257+
258+
:param args: inputs
259+
:return: output OpRunValue
260+
"""
261+
assert all(
262+
isinstance(a, torch_ops.OpRunValue) for a in args
263+
), f"Unexpected type in args: {[type(a) for a in args]}"
264+
outputs = self.output_names
265+
266+
# sets constants
267+
for k, v in self.constants.items():
268+
r = self.runtime_info[k]
269+
if not r.has_value:
270+
r.set_value(
271+
torch_ops.OpRunValue(
272+
v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
273+
)
274+
)
275+
276+
# inputs
277+
for k, v in zip(self.input_names, args):
278+
r = self.runtime_info[k]
279+
r.set_value(torch_ops.OpRunValue(None if v is None else v.tensor))
280+
281+
# node execution
282+
for it, kernel in enumerate(self.kernels):
283+
if kernel is not None:
284+
# kernel execution
285+
inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
286+
res = kernel.run(*inputs)
287+
if isinstance(res, tuple):
288+
# outputs
289+
assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
290+
f"Unexpected output type {[type(o) for o in res]} "
291+
f"for kernel {type(kernel)}."
292+
)
293+
for name, t in zip(kernel.output, res):
294+
self.runtime_info[name].set_value(t)
295+
else:
296+
assert isinstance(
297+
res, torch_ops.OpRunValue
298+
), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
299+
self.runtime_info[kernel.output[0]].set_value(res)
300+
301+
# free intermediate results
302+
for name in self.last_used[it]:
303+
self.runtime_info[name].clean_value()
304+
305+
assert all(
306+
self.runtime_info[o].value is not None for o in outputs
307+
), "Not implemented yet when one output is None."
308+
res2 = [torch_ops.OpRunValue(self.runtime_info[o].value.tensor) for o in outputs] # type: ignore[assignment, union-attr]
309+
310+
# clean previous execution
311+
for k in self.input_names:
312+
self.runtime_info[k].clean_value()
313+
for o in self.output_names:
314+
self.runtime_info[o].clean_value()
315+
316+
return res2[0] if len(res2) == 1 else tuple(res2) # type: ignore[index, return-value, arg-type]

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._op_run import OpRun, OpRunValue
1+
from ._op_run import OpRun, OpRunFunction, OpRunValue
22
from .access_ops import Gather_1, Slice_13
33
from .binary_ops import (
44
And_1,
@@ -29,6 +29,7 @@
2929
Unsqueeze_13,
3030
)
3131
from .unary_ops import (
32+
Abs_1,
3233
Cos_1,
3334
Exp_1,
3435
Log_1,

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def __str__(self) -> str:
9090
)
9191
return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"
9292

93-
def run(self, *args) -> Union[OpRunValue, Tuple[OpRunValue, ...]]:
93+
def run(
94+
self, *args: Optional[OpRunValue]
95+
) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
9496
"Kernel implementation."
9597
raise NotImplementedError(
9698
f"Method run is not implemented for kernel {self.__class__.__name__!r}"
@@ -157,3 +159,24 @@ def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torc
157159
if att is None:
158160
return None
159161
return to_tensor(att.t)
162+
163+
164+
class OpRunFunction(OpRun):
165+
"""
166+
Defines a kernel based on a local function.
167+
"""
168+
169+
def __init__(
170+
self,
171+
runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator", # noqa: F821
172+
node: onnx.NodeProto,
173+
version: Optional[int] = None,
174+
):
175+
super().__init__(node, version)
176+
self.runtime = runtime
177+
self.input_names = runtime.input_names
178+
179+
def run(
180+
self, *args: Optional[OpRunValue]
181+
) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
182+
return self.runtime.run_with_values(*args)

onnx_diagnostic/reference/torch_ops/unary_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
from . import OpRun, OpRunValue
33

44

5+
class Abs_1(OpRun):
6+
"""Abs"""
7+
8+
def run(self, x: OpRunValue) -> OpRunValue:
9+
return OpRunValue(torch.abs(x.tensor))
10+
11+
512
class Cos_1(OpRun):
613
"""Cos"""
714

onnx_diagnostic/torch_onnx/runtime_info.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,28 +167,49 @@ def set_is_shape(node: onnx.NodeProto, values: Dict[str, RuntimeValue]) -> List[
167167

168168

169169
def first_used_last_used(
170-
proto: onnx.ModelProto, constant_as_initializer: bool = False
170+
proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
171+
constant_as_initializer: bool = False,
171172
) -> Dict[str, RuntimeValue]:
172173
"""
173174
Builds first used, last used information for every result
174175
in the model.
175176
176-
:param proto: model
177+
:param proto: model, graph or function
177178
:param constant_as_initializer: outputs of node Constant is tagged as INITIALIZER
178179
:return: dictionary of RuntimeValue
179180
"""
180181
values = {}
181-
for init in proto.graph.initializer:
182+
if isinstance(proto, onnx.ModelProto):
183+
initializer = proto.graph.initializer
184+
sparse_initializer = proto.graph.sparse_initializer
185+
_input = proto.graph.input
186+
output = proto.graph.output
187+
_node = proto.graph.node
188+
elif isinstance(proto, onnx.GraphProto):
189+
initializer = proto.initializer
190+
sparse_initializer = proto.sparse_initializer
191+
_input = proto.input
192+
output = proto.output
193+
_node = proto.node
194+
else:
195+
initializer = []
196+
sparse_initializer = []
197+
_input = proto.input
198+
output = proto.output
199+
_node = proto.node
200+
201+
for init in initializer:
182202
values[init.name] = RuntimeValue(
183203
init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
184204
)
185-
for init in proto.graph.sparse_initializer:
205+
for init in sparse_initializer:
186206
values[init.name] = RuntimeValue(
187207
init.name, created=-1, kind=RuntimeValueKind.INITIALIZER
188208
)
189-
for inp in proto.graph.input:
190-
values[inp.name] = RuntimeValue(inp.name, created=-1, kind=RuntimeValueKind.INPUT)
191-
for it, node in enumerate(proto.graph.node):
209+
for inp in _input:
210+
n = inp if isinstance(inp, str) else inp.name
211+
values[n] = RuntimeValue(n, created=-1, kind=RuntimeValueKind.INPUT)
212+
for it, node in enumerate(_node):
192213
for i in node.input:
193214
if values[i].first_used is None:
194215
values[i].first_used = it
@@ -212,7 +233,8 @@ def first_used_last_used(
212233
)
213234
set_is_shape(node, values)
214235

215-
for out in proto.graph.output:
216-
values[out.name].kind = RuntimeValueKind.OUTPUT
217-
values[out.name].last_used = len(proto.graph.node)
236+
for out in output:
237+
n = out if isinstance(out, str) else out.name
238+
values[n].kind = RuntimeValueKind.OUTPUT
239+
values[n].last_used = len(_node)
218240
return values

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ disable_error_code = ["assignment", "arg-type", "name-defined", "union-attr"]
4444
module = ["onnx_diagnostic.helpers.ort_session"]
4545
disable_error_code = ["union-attr"]
4646

47+
[[tool.mypy.overrides]]
48+
module = ["onnx_diagnostic.reference.torch_ops.*"]
49+
disable_error_code = ["override"]
50+
51+
[[tool.mypy.overrides]]
52+
module = ["onnx_diagnostic.reference.torch_ops._op_run"]
53+
disable_error_code = ["name-defined"]
54+
4755
[[tool.mypy.overrides]]
4856
module = ["onnx_diagnostic.torch_export_patches.*"]
4957
disable_error_code = ["arg-type", "assignment", "attr-defined", "index", "misc", "name-defined", "operator", "return-value", "union-attr"]

0 commit comments

Comments
 (0)