Skip to content

Commit e649e25

Browse files
authored
Add verbose to OpRunKernel (#136)
* add verbosity * mypy * doc * ci * mypy
1 parent 3038095 commit e649e25

File tree

15 files changed

+340
-48
lines changed

15 files changed

+340
-48
lines changed

_doc/api/helpers/doc_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.doc_helper
3+
==================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.doc_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ onnx_diagnostic.helpers
1010
bench_run
1111
cache_helper
1212
config_helper
13+
doc_helper
1314
graph_helper
1415
helper
1516
memory_peak
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import unittest
2+
import onnx
3+
import onnx.helper as oh
4+
import torch
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt
7+
from onnx_diagnostic.reference import TorchOnnxEvaluator
8+
9+
TFLOAT = onnx.TensorProto.FLOAT
10+
TFLOAT16 = onnx.TensorProto.FLOAT16
11+
12+
13+
class TestDocHelper(ExtTestCase):
    """Checks that the onnxruntime-backed documentation kernels
    (:class:`LayerNormalizationOrt`, :class:`MatMulOrt`) produce the same
    results as the default torch kernels of :class:`TorchOnnxEvaluator`.
    """

    def test_custom_doc_kernels_layer_normalization(self):
        """LayerNormalizationOrt must match the torch kernel in float16."""
        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        ["X", "W", "B"],
                        ["ln"],
                        axis=-1,
                        epsilon=9.999999974752427e-7,
                    ),
                    # FIX: the original node carried axis/epsilon attributes
                    # copy-pasted from LayerNormalization; Add defines none
                    # and onnx.checker rejects unknown attributes.
                    oh.make_node("Add", ["ln", "W"], ["Z"]),
                ],
                "dummy",
                [
                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
                    oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
                    oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )

        torch_sess = TorchOnnxEvaluator(model, verbose=0)
        torch_sess_custom = TorchOnnxEvaluator(
            model,
            verbose=0,
            custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
        )
        feeds = dict(
            zip(
                torch_sess.input_names,
                [
                    torch.rand(3, 4, 5, dtype=torch.float16),
                    torch.abs(torch.rand(5, dtype=torch.float16)),
                    torch.rand(5, dtype=torch.float16),
                ],
            )
        )
        expected = torch_sess.run(None, feeds)
        got = torch_sess_custom.run(None, feeds)
        self.assertEqualAny(expected, got)

    def test_custom_doc_kernels_matmul(self):
        """MatMulOrt must match the torch kernel in float16."""
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["X", "Y"], ["Z"])],
                "dummy",
                [
                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
                    oh.make_tensor_value_info("Y", TFLOAT16, ["b", "d", "e"]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "e"])],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )

        torch_sess = TorchOnnxEvaluator(model, verbose=0)
        torch_sess_custom = TorchOnnxEvaluator(
            model,
            verbose=0,
            custom_kernels={("", "MatMul"): MatMulOrt},
        )
        feeds = dict(
            zip(
                torch_sess.input_names,
                [
                    torch.rand(3, 4, 5, dtype=torch.float16),
                    torch.rand(3, 5, 7, dtype=torch.float16),
                ],
            )
        )
        expected = torch_sess.run(None, feeds)
        got = torch_sess_custom.run(None, feeds)
        self.assertEqualAny(expected, got)
96+
if __name__ == "__main__":
    # Allow running this test file directly with verbose unittest output.
    unittest.main(verbosity=2)

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,8 +1378,8 @@ class LayerNormalizationOrt(OpRunKernel):
13781378

13791379
_shared = [0]
13801380

1381-
def __init__(self, node: onnx.NodeProto, version=None):
1382-
super().__init__(node, version)
1381+
def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
1382+
super().__init__(node, version, verbose=verbose)
13831383
self.axis = self.get_attribute_int(node, "axis", -1)
13841384
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
13851385
self.stash_type = onnx_dtype_to_torch_dtype(
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from typing import Any, Dict, Optional, Tuple
2+
import onnx
3+
import onnx.helper as oh
4+
import torch
5+
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
6+
from ..reference.torch_ops import OpRunKernel, OpRunTensor
7+
8+
9+
class LayerNormalizationOrt(OpRunKernel):
    """LayerNormalization delegated to onnxruntime.

    A one-node onnx model is built lazily for every ``(element type, rank)``
    pair, compiled into an ``onnxruntime.InferenceSession`` and cached;
    :meth:`run` feeds the torch tensors to that session through numpy.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        # Returning False means the evaluator does not pass a device
        # positionally; the optional ``device`` constructor argument stays
        # None unless a caller sets it explicitly.
        return False

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose=0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.device = device
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        # A second output (mean/inv_std_dev) is not supported by this kernel.
        self.compute_std = len(node.output) > 1
        assert not self.compute_std, (
            f"This kernel implementation only works when only one output "
            f"is required but {node.output} were."
        )
        # FIX: the cache stores InferenceSession objects, one per
        # (itype, rank), not ModelProto as the original annotation claimed.
        self._cache: Dict[Tuple[int, int], Any] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(self, itype: int, rank: int) -> Any:
        """Builds a one-node LayerNormalization model and compiles it.

        Returns an ``onnxruntime.InferenceSession`` (annotated ``Any``
        because onnxruntime is imported lazily).
        """
        # FIX: the original used "d{i}" without the f prefix, which gave
        # every leading dimension the same symbolic name and therefore
        # constrained them all to be equal.
        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
        layer_model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        ["X", "W", "B"],
                        ["Z"],
                        axis=self.axis,
                        epsilon=self.epsilon,
                    )
                ],
                "dummy",
                [
                    oh.make_tensor_value_info("X", itype, shape),
                    oh.make_tensor_value_info("W", itype, ["last"]),
                    oh.make_tensor_value_info("B", itype, ["last"]),
                ],
                [oh.make_tensor_value_info("Z", itype, shape)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        import onnxruntime

        # NOTE(review): when no device is given, self.device is None, so
        # is_cpu is False and CUDA is requested — confirm this default.
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        return onnxruntime.InferenceSession(
            layer_model.SerializeToString(), providers=[provider]
        )

    def run(self, x, scale, bias=None):
        """Runs LayerNormalization(x, scale[, bias]) through onnxruntime."""
        itype = torch_dtype_to_onnx_dtype(x.dtype)
        rank = len(x.shape)
        key = itype, rank
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, rank)
        sess = self._cache[key]
        feeds = dict(X=x, W=scale)
        if bias is not None:
            feeds["B"] = bias
        # onnxruntime consumes numpy arrays: detach and move to cpu first,
        # then bring the result back to the input's dtype and device.
        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
        got = sess.run(None, feeds)[0]
        return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
85+
class MatMulOrt(OpRunKernel):
    """MatMul delegated to onnxruntime.

    A one-node onnx model is built lazily for every
    ``(element type, rank(A), rank(B))`` triple, compiled into an
    ``onnxruntime.InferenceSession`` and cached; :meth:`run` feeds the
    torch tensors to that session through numpy.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        # Returning False means the evaluator does not pass a device
        # positionally; the optional ``device`` constructor argument stays
        # None unless a caller sets it explicitly.
        return False

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose=0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.device = device
        # FIX: the cache stores InferenceSession objects, not ModelProto
        # as the original annotation claimed.
        self._cache: Dict[Tuple[int, int, int], Any] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(self, itype: int, ranka: int, rankb: int) -> Any:
        """Builds a one-node MatMul model and compiles it.

        Returns an ``onnxruntime.InferenceSession`` (annotated ``Any``
        because onnxruntime is imported lazily).
        """
        # FIX: the original used "a{i}"/"b{i}"/"c{i}" without the f prefix,
        # which gave every dimension of a tensor the same symbolic name and
        # therefore constrained them all to be equal.
        shapea = [f"a{i}" for i in range(ranka)]
        shapeb = [f"b{i}" for i in range(rankb)]
        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["A", "B"], ["C"])],
                "dummy",
                [
                    oh.make_tensor_value_info("A", itype, shapea),
                    oh.make_tensor_value_info("B", itype, shapeb),
                ],
                [oh.make_tensor_value_info("C", itype, shapec)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 17)],
        )
        import onnxruntime

        # NOTE(review): when no device is given, self.device is None, so
        # is_cpu is False and CUDA is requested — confirm this default.
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        return onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider])

    def run(self, a, b):
        """Runs MatMul(a, b) through onnxruntime."""
        itype = torch_dtype_to_onnx_dtype(a.dtype)
        ranka, rankb = len(a.shape), len(b.shape)
        key = itype, ranka, rankb
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, ranka, rankb)
        sess = self._cache[key]
        feeds = dict(A=a, B=b)
        # onnxruntime consumes numpy arrays: detach and move to cpu first,
        # then bring the result back to the input's dtype and device.
        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
        got = sess.run(None, feeds)[0]
        return OpRunTensor(torch.from_numpy(got).to(a.dtype).to(a.device))

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,19 +410,24 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
410410
kernels = get_kernels()
411411
self.kernels.clear()
412412
for node in nodes:
413+
kernel_kwargs = dict(verbose=max(0, self.verbose - 1))
413414
opset = self.opsets[node.domain]
414415
key = node.domain, node.op_type, opset
415416
if key[:2] in self.custom_kernels:
416417
cls = self.custom_kernels[key[:2]]
417418
ags = [self.default_device] if cls.device_dependent() else []
418419
kws = dict(parent=self) if cls.has_subgraphs() else {}
419-
kernel2 = cls(node, opset, *ags, **kws)
420+
kws.update(kernel_kwargs) # type: ignore[arg-type]
421+
kernel2 = cls(node, opset, *ags, **kws) # type: ignore[arg-type]
420422
self.kernels.append(kernel2)
421423
continue
422424

423425
if (node.domain, node.op_type) in self.functions:
424426
kernel = torch_ops.OpRunFunction(
425-
self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
427+
self.functions[node.domain, node.op_type],
428+
node,
429+
self.opsets[node.domain],
430+
**kernel_kwargs,
426431
)
427432
self.kernels.append(kernel)
428433
continue
@@ -442,7 +447,8 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
442447
cls = kernels[key]
443448
ags = [self.default_device] if cls.device_dependent() else []
444449
kws = dict(parent=self) if cls.has_subgraphs() else {}
445-
kernel2 = cls(node, opset, *ags, **kws)
450+
kws.update(kernel_kwargs) # type: ignore[arg-type]
451+
kernel2 = cls(node, opset, *ags, **kws) # type: ignore[arg-type]
446452
self.kernels.append(kernel2)
447453

448454
def run(

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Optional, Union, Tuple
1+
from typing import Any, Dict, List, Optional, Union, Tuple
22
import onnx
33
import torch
44
from ...api import TensorLike
@@ -185,14 +185,22 @@ def has_subgraphs(cls) -> bool:
185185
"""Returns True if the kernel has subgraphs."""
186186
return False
187187

188-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
188+
def __init__(
189+
self,
190+
node: onnx.NodeProto,
191+
version: Optional[int] = None,
192+
verbose: int = 0,
193+
custom_kernels: Optional[Dict[Tuple[str, str], type]] = None,
194+
):
189195
assert isinstance(
190196
node, onnx.NodeProto
191197
), f"node must be a NodeProto but node is {type(node)}"
192198
self.op_type = node.op_type
193199
self.domain = node.domain
194200
self.input = node.input
195201
self.output = node.output
202+
self.verbose = verbose
203+
self.custom_kernels = custom_kernels
196204
if version is None:
197205
name = self.__class__.__name__.split("_")
198206
assert (
@@ -315,8 +323,9 @@ def __init__(
315323
runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator", # noqa: F821
316324
node: onnx.NodeProto,
317325
version: Optional[int] = None,
326+
verbose: int = 0,
318327
):
319-
super().__init__(node, version)
328+
super().__init__(node, version, verbose=verbose)
320329
self.runtime = runtime
321330
self.input_names = runtime.input_names
322331

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
class Gather_1(OpRunKernel):
88
"Gather"
99

10-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
11-
super().__init__(node, version)
10+
def __init__(
11+
self,
12+
node: onnx.NodeProto,
13+
version: Optional[int] = None,
14+
verbose: int = 0,
15+
):
16+
super().__init__(node, version, verbose=verbose)
1217
axis = self.get_attribute_int(node, "axis", 0)
1318
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
1419
self.axis = axis
@@ -24,8 +29,13 @@ def run(self, x, indices):
2429
class ScatterND_16(OpRunKernel):
2530
"ScatterND"
2631

27-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
28-
super().__init__(node, version)
32+
def __init__(
33+
self,
34+
node: onnx.NodeProto,
35+
version: Optional[int] = None,
36+
verbose: int = 0,
37+
):
38+
super().__init__(node, version, verbose=verbose)
2939
self.reduction = self.get_attribute_string(node, "reduction", "none")
3040

3141
def run(

0 commit comments

Comments
 (0)