Skip to content

Commit 15fe928

Browse files
committed
Rename OpRun into OpRunKernel
1 parent c87ef68 commit 15fe928

File tree

15 files changed

+64
-64
lines changed

15 files changed

+64
-64
lines changed

_doc/api/reference/torch_ops/index.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ onnx_diagnostic.reference.torch_ops
1818
shape_ops
1919
unary_ops
2020

21-
OpRun
22-
+++++
21+
OpRunKernel
22+
+++++++++++
2323

24-
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRun
24+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunKernel
2525
:members:
2626

2727
OpRunTensor

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
99
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
1010
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
11-
from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor
11+
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
1212
from onnx_diagnostic.reference.torch_evaluator import get_kernels
1313

1414

@@ -1373,7 +1373,7 @@ def test_tile(self):
13731373
)
13741374

13751375
def test_custom_kernels(self):
1376-
class LayerNormalizationOrt(OpRun):
1376+
class LayerNormalizationOrt(OpRunKernel):
13771377
"LayerNormalization"
13781378

13791379
_shared = [0]

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
@functools.lru_cache
12-
def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRun]]:
12+
def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRunKernel]]:
1313
"""
1414
Retrieves all the available kernels class :class:`TorchOnnxEvaluator`
1515
can use. The full list is the following.
@@ -28,7 +28,7 @@ def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRun]]:
2828
"""
2929
res = {}
3030
for _k, v in torch_ops.__dict__.items():
31-
if isinstance(v, type) and issubclass(v, torch_ops.OpRun) and "_" in v.__name__:
31+
if isinstance(v, type) and issubclass(v, torch_ops.OpRunKernel) and "_" in v.__name__:
3232
name, version = v.__name__.split("_")
3333
domain = getattr(v, "domain", "")
3434
res[domain, name, int(version)] = v
@@ -161,11 +161,11 @@ class TorchOnnxEvaluator:
161161
from onnx_diagnostic.helpers import string_type
162162
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
163163
from onnx_diagnostic.reference import TorchOnnxEvaluator
164-
from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor
164+
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
165165
166166
TFLOAT16 = onnx.TensorProto.FLOAT16
167167
168-
class LayerNormalizationOrt(OpRun):
168+
class LayerNormalizationOrt(OpRunKernel):
169169
"LayerNormalization based on onnxruntime"
170170
171171
def __init__(self, node: onnx.NodeProto, version=None):
@@ -284,11 +284,11 @@ def __init__(
284284
opsets: Optional[Dict[str, int]] = None,
285285
local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
286286
verbose: int = 0,
287-
custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRun]]] = None,
287+
custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRunKernel]]] = None,
288288
):
289289
self.providers = providers
290290
self.constants: Dict[str, torch.Tensor] = {}
291-
self.kernels: List[Optional[torch_ops.OpRun]] = []
291+
self.kernels: List[Optional[torch_ops.OpRunKernel]] = []
292292
self.functions = local_functions.copy() if local_functions else {}
293293
self.CPU = torch.tensor([0]).to("cpu").device
294294
self.verbose = verbose

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._op_run import OpRun, OpRunFunction, OpRunSequence, OpRunTensor, OpRunValue
1+
from ._op_run import OpRunKernel, OpRunFunction, OpRunSequence, OpRunTensor, OpRunValue
22
from .access_ops import Gather_1, ScatterND_16, Slice_13
33
from .binary_ops import (
44
And_1,

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def string_type(self) -> str:
167167
return string_type(self.sequence, with_shape=True)
168168

169169

170-
class OpRun:
170+
class OpRunKernel:
171171
"""
172172
Main class. Every kernel should inherit from it.
173173
It does not copy the proto.
@@ -305,7 +305,7 @@ def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
305305
return tuple(t.to(device) for t in tensors)
306306

307307

308-
class OpRunFunction(OpRun):
308+
class OpRunFunction(OpRunKernel):
309309
"""
310310
Defines a kernel based on a local functions.
311311
"""

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Optional
22
import onnx
33
import torch
4-
from . import OpRun, OpRunTensor
4+
from . import OpRunKernel, OpRunTensor
55

66

7-
class Gather_1(OpRun):
7+
class Gather_1(OpRunKernel):
88
"Gather"
99

1010
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -21,7 +21,7 @@ def run(self, x, indices):
2121
return OpRunTensor(x.tensor[tuple(ind)])
2222

2323

24-
class ScatterND_16(OpRun):
24+
class ScatterND_16(OpRunKernel):
2525
"ScatterND"
2626

2727
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -54,7 +54,7 @@ def run(
5454
return OpRunTensor(output)
5555

5656

57-
class Slice_13(OpRun):
57+
class Slice_13(OpRunKernel):
5858
"Slice"
5959

6060
def run(

onnx_diagnostic/reference/torch_ops/binary_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
2-
from . import OpRun, OpRunTensor
2+
from . import OpRunKernel, OpRunTensor
33

44

5-
class OpRunBinary(OpRun):
5+
class OpRunBinary(OpRunKernel):
66
"Binary Op"
77

88
def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:

onnx_diagnostic/reference/torch_ops/controlflow_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Any, Dict, Optional
22
import onnx
33
import torch
4-
from . import OpRun, OpRunTensor
4+
from . import OpRunKernel, OpRunTensor
55

66

7-
class OpRunControlFlow(OpRun):
7+
class OpRunControlFlow(OpRunKernel):
88
"""Common ancestor for control flows."""
99

1010
@classmethod

onnx_diagnostic/reference/torch_ops/generator_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Optional
22
import onnx
33
import torch
4-
from . import OpRun, OpRunTensor
4+
from . import OpRunKernel, OpRunTensor
55

66

7-
class Range_11(OpRun):
7+
class Range_11(OpRunKernel):
88
"""Range"""
99

1010
@classmethod

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import onnx
33
import torch
44
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5-
from . import OpRun, OpRunTensor
5+
from . import OpRunKernel, OpRunTensor
66

77

8-
class AveragePool_11(OpRun):
8+
class AveragePool_11(OpRunKernel):
99
"AveragePool"
1010

1111
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -44,7 +44,7 @@ def run(self, x):
4444
)
4545

4646

47-
class Conv_11(OpRun):
47+
class Conv_11(OpRunKernel):
4848
"Conv"
4949

5050
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -108,7 +108,7 @@ def run(self, x, w, b=None):
108108
)
109109

110110

111-
class LayerNormalization_17(OpRun):
111+
class LayerNormalization_17(OpRunKernel):
112112
"LayerNormalization"
113113

114114
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -152,7 +152,7 @@ def run(self, x, scale, bias=None):
152152
)
153153

154154

155-
class Softmax_13(OpRun):
155+
class Softmax_13(OpRunKernel):
156156
"Softmax"
157157

158158
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
@@ -169,7 +169,7 @@ def run(self, data: OpRunTensor) -> OpRunTensor:
169169
)
170170

171171

172-
class Tanh_6(OpRun):
172+
class Tanh_6(OpRunKernel):
173173
"Tanh"
174174

175175
def run(self, data: OpRunTensor) -> OpRunTensor:

0 commit comments

Comments
 (0)