Skip to content

Commit 760d46f

Browse files
committed
add verbosity
1 parent 3038095 commit 760d46f

File tree

11 files changed

+96
-46
lines changed

11 files changed

+96
-46
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,8 +1378,8 @@ class LayerNormalizationOrt(OpRunKernel):
13781378

13791379
_shared = [0]
13801380

1381-
def __init__(self, node: onnx.NodeProto, version=None):
1382-
super().__init__(node, version)
1381+
def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
1382+
super().__init__(node, version, verbose=verbose)
13831383
self.axis = self.get_attribute_int(node, "axis", -1)
13841384
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
13851385
self.stash_type = onnx_dtype_to_torch_dtype(

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,19 +410,24 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
410410
kernels = get_kernels()
411411
self.kernels.clear()
412412
for node in nodes:
413+
kernel_kwargs = dict(verbose=max(0, self.verbose - 1))
413414
opset = self.opsets[node.domain]
414415
key = node.domain, node.op_type, opset
415416
if key[:2] in self.custom_kernels:
416417
cls = self.custom_kernels[key[:2]]
417418
ags = [self.default_device] if cls.device_dependent() else []
418419
kws = dict(parent=self) if cls.has_subgraphs() else {}
420+
kws.update(kernel_kwargs)
419421
kernel2 = cls(node, opset, *ags, **kws)
420422
self.kernels.append(kernel2)
421423
continue
422424

423425
if (node.domain, node.op_type) in self.functions:
424426
kernel = torch_ops.OpRunFunction(
425-
self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
427+
self.functions[node.domain, node.op_type],
428+
node,
429+
self.opsets[node.domain],
430+
**kernel_kwargs,
426431
)
427432
self.kernels.append(kernel)
428433
continue
@@ -442,6 +447,7 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
442447
cls = kernels[key]
443448
ags = [self.default_device] if cls.device_dependent() else []
444449
kws = dict(parent=self) if cls.has_subgraphs() else {}
450+
kws.update(kernel_kwargs)
445451
kernel2 = cls(node, opset, *ags, **kws)
446452
self.kernels.append(kernel2)
447453

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Optional, Union, Tuple
1+
from typing import Any, Dict, List, Optional, Union, Tuple
22
import onnx
33
import torch
44
from ...api import TensorLike
@@ -185,14 +185,22 @@ def has_subgraphs(cls) -> bool:
185185
"""Returns True if the kernel has subgraphs."""
186186
return False
187187

188-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
188+
def __init__(
189+
self,
190+
node: onnx.NodeProto,
191+
version: Optional[int] = None,
192+
verbose: int = 0,
193+
custom_kernels: Optional[Dict[Tuple[str, str], type]] = None,
194+
):
189195
assert isinstance(
190196
node, onnx.NodeProto
191197
), f"node must be a NodeProto but node is {type(node)}"
192198
self.op_type = node.op_type
193199
self.domain = node.domain
194200
self.input = node.input
195201
self.output = node.output
202+
self.verbose = verbose
203+
self.custom_kernels = custom_kernels
196204
if version is None:
197205
name = self.__class__.__name__.split("_")
198206
assert (
@@ -315,8 +323,9 @@ def __init__(
315323
runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator", # noqa: F821
316324
node: onnx.NodeProto,
317325
version: Optional[int] = None,
326+
verbose: int = 0,
318327
):
319-
super().__init__(node, version)
328+
super().__init__(node, version, verbose=verbose)
320329
self.runtime = runtime
321330
self.input_names = runtime.input_names
322331

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
class Gather_1(OpRunKernel):
88
"Gather"
99

10-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
11-
super().__init__(node, version)
10+
def __init__(
11+
self,
12+
node: onnx.NodeProto,
13+
version: Optional[int] = None,
14+
verbose: int = 0,
15+
):
16+
super().__init__(node, version, verbose=verbose)
1217
axis = self.get_attribute_int(node, "axis", 0)
1318
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
1419
self.axis = axis
@@ -24,8 +29,13 @@ def run(self, x, indices):
2429
class ScatterND_16(OpRunKernel):
2530
"ScatterND"
2631

27-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
28-
super().__init__(node, version)
32+
def __init__(
33+
self,
34+
node: onnx.NodeProto,
35+
version: Optional[int] = None,
36+
verbose: int = 0,
37+
):
38+
super().__init__(node, version, verbose=verbose)
2939
self.reduction = self.get_attribute_string(node, "reduction", "none")
3040

3141
def run(

onnx_diagnostic/reference/torch_ops/controlflow_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ def __init__(
1717
node: onnx.NodeProto,
1818
version: Optional[int] = None,
1919
parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None, # noqa: F821
20+
verbose: int = 0,
2021
):
21-
super().__init__(node, version)
22+
super().__init__(node, version, verbose=verbose)
2223
assert (
2324
parent is not None
2425
), f"parent must be specified for operator {self.__class__.__name__!r}"
@@ -30,6 +31,7 @@ def __init__(
3031
opsets=parent.opsets,
3132
local_functions=parent.functions,
3233
verbose=parent.verbose,
34+
custom_kernels=parent.custom_kernels,
3335
)
3436
setattr(self, att.name, rt)
3537

@@ -50,8 +52,9 @@ def __init__(
5052
node: onnx.NodeProto,
5153
version: Optional[int] = None,
5254
parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None, # noqa: F821
55+
verbose: int = 0,
5356
):
54-
super().__init__(node, version, parent)
57+
super().__init__(node, version, parent, verbose=verbose)
5558
self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
5659
self.N = len(self.body.input_names) - 2
5760
self.K = len(self.body.output_names) - self.N - 1

onnx_diagnostic/reference/torch_ops/generator_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ def __init__(
1919
node: onnx.NodeProto,
2020
version: Optional[int] = None,
2121
device: Optional[torch.device] = None,
22+
verbose: int = 0,
2223
):
23-
super().__init__(node, version)
24+
super().__init__(node, version, verbose=verbose)
2425
self.device = device
2526

2627
def run(self, starts: OpRunTensor, limit: OpRunTensor, delta: OpRunTensor) -> OpRunTensor:

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
class AveragePool_11(OpRunKernel):
99
"AveragePool"
1010

11-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
12-
super().__init__(node, version)
11+
def __init__(
12+
self,
13+
node: onnx.NodeProto,
14+
version: Optional[int] = None,
15+
verbose: int = 0,
16+
):
17+
super().__init__(node, version, verbose=verbose)
1318
self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
1419
self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
1520
self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
@@ -47,8 +52,13 @@ def run(self, x):
4752
class Conv_11(OpRunKernel):
4853
"Conv"
4954

50-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
51-
super().__init__(node, version)
55+
def __init__(
56+
self,
57+
node: onnx.NodeProto,
58+
version: Optional[int] = None,
59+
verbose: int = 0,
60+
):
61+
super().__init__(node, version, verbose=verbose)
5262
self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
5363
self.dilations = self.get_attribute_ints(node, "dilations", None)
5464
self.group = self.get_attribute_int(node, "group", 1)
@@ -111,8 +121,13 @@ def run(self, x, w, b=None):
111121
class LayerNormalization_17(OpRunKernel):
112122
"LayerNormalization"
113123

114-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
115-
super().__init__(node, version)
124+
def __init__(
125+
self,
126+
node: onnx.NodeProto,
127+
version: Optional[int] = None,
128+
verbose: int = 0,
129+
):
130+
super().__init__(node, version, verbose=verbose)
116131
self.axis = self.get_attribute_int(node, "axis", -1)
117132
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
118133
self.stash_type = onnx_dtype_to_torch_dtype(
@@ -155,8 +170,13 @@ def run(self, x, scale, bias=None):
155170
class Softmax_13(OpRunKernel):
156171
"Softmax"
157172

158-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
159-
super().__init__(node, version)
173+
def __init__(
174+
self,
175+
node: onnx.NodeProto,
176+
version: Optional[int] = None,
177+
verbose: int = 0,
178+
):
179+
super().__init__(node, version, verbose=verbose)
160180
self.axis = self.get_attribute_int(node, "axis", -1)
161181
assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
162182
# this is out of spec

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
class Cast_6(OpRunKernel):
99
"Cast"
1010

11-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
12-
super().__init__(node, version)
11+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
12+
super().__init__(node, version, verbose=verbose)
1313
to = self.get_attribute_int(node, "to", 0)
1414
assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
1515
self.to = onnx_dtype_to_torch_dtype(to)
@@ -23,8 +23,8 @@ def run(self, data: OpRunTensor) -> OpRunTensor:
2323
class CastLike_15(OpRunKernel):
2424
"Cast"
2525

26-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
27-
super().__init__(node, version)
26+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
27+
super().__init__(node, version, verbose=verbose)
2828
self.saturate = self.get_attribute_int(node, "saturate", 1)
2929
assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"
3030

@@ -35,8 +35,8 @@ def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
3535
class Concat_1(OpRunKernel):
3636
"Concat"
3737

38-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
39-
super().__init__(node, version)
38+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
39+
super().__init__(node, version, verbose=verbose)
4040
axis = self.get_attribute_int(node, "axis", 0)
4141
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
4242
self.axis = axis
@@ -76,8 +76,8 @@ def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
7676
class Transpose_1(OpRunKernel):
7777
"Transpose"
7878

79-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
80-
super().__init__(node, version)
79+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
80+
super().__init__(node, version, verbose=verbose)
8181
self.perm = self.get_attribute_ints(node, "perm", None)
8282

8383
def run(self, data: OpRunTensor) -> OpRunTensor:
@@ -87,8 +87,8 @@ def run(self, data: OpRunTensor) -> OpRunTensor:
8787
class Trilu_14(OpRunKernel):
8888
"Trilu"
8989

90-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
91-
super().__init__(node, version)
90+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
91+
super().__init__(node, version, verbose=verbose)
9292
self.upper = self.get_attribute_int(node, "upper", 1)
9393

9494
def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:

onnx_diagnostic/reference/torch_ops/reduce_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
class ReduceOp(OpRunKernel):
9-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
10-
super().__init__(node, version)
9+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
10+
super().__init__(node, version, verbose=verbose)
1111
self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
1212
self.noop_with_empty_axes = bool(
1313
self.get_attribute_int(node, "noop_with_empty_axes", 0)
@@ -28,8 +28,8 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2828

2929

3030
class ReduceOpAxes(ReduceOp):
31-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
32-
super().__init__(node, version)
31+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
32+
super().__init__(node, version, verbose=verbose)
3333
self.axes: Tuple[int, ...] = self.get_attribute_ints(node, "axes") or tuple()
3434

3535

onnx_diagnostic/reference/torch_ops/sequence_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class OpRunOpSequence(OpRunKernel):
1212
class ConcatFromSequence_11(OpRunOpSequence):
1313
"ConcatFromSequence"
1414

15-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
16-
super().__init__(node, version)
15+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
16+
super().__init__(node, version, verbose=verbose)
1717
axis = self.get_attribute_int(node, "axis", None)
1818
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
1919
self.axis = axis
@@ -39,8 +39,8 @@ def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
3939
class SequenceEmpty_11(OpRunOpSequence):
4040
"SqeuenceEmpty"
4141

42-
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
43-
super().__init__(node, version)
42+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
43+
super().__init__(node, version, verbose=verbose)
4444
self.dtype = onnx_dtype_to_torch_dtype(
4545
self.get_attribute_int(node, "dtype", onnx.TensorProto.FLOAT) # type: ignore[arg-type]
4646
)

0 commit comments

Comments
 (0)