Skip to content

Commit 68f5d4a

Browse files
committed
refactoring
1 parent db3b67c commit 68f5d4a

File tree

19 files changed

+226
-175
lines changed

19 files changed

+226
-175
lines changed

_doc/api/api.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.api
3+
===================
4+
5+
.. automodule:: onnx_diagnostic.api
6+
:members:
7+
:no-undoc-members:

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ API of onnx_diagnostic
1919
:maxdepth: 1
2020
:caption: modules
2121

22+
api
2223
ext_test_case
2324

2425
.. automodule:: onnx_diagnostic

_doc/api/reference/torch_ops/index.rst

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,22 @@ OpRun
2424
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRun
2525
:members:
2626

27+
OpRunTensor
28+
+++++++++++
29+
30+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunTensor
31+
:members:
32+
2733
OpRunValue
2834
++++++++++
2935

3036
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValue
3137
:members:
3238

33-
OpRunValueSequence
34-
++++++++++++++++++
39+
OpRunSequence
40+
+++++++++++++
3541

36-
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValueSequence
42+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunSequence
3743
:members:
3844

3945
OpRunFunction

onnx_diagnostic/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class TensorLike:
2+
"""Mocks a tensor."""

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def run(
264264
r = self.runtime_info[k]
265265
if not r.has_value:
266266
r.set_value(
267-
torch_ops.OpRunValue(
267+
torch_ops.OpRunTensor(
268268
v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
269269
is_constant=True,
270270
may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
@@ -277,7 +277,7 @@ def run(
277277
for k, v in feeds.items():
278278
r = self.runtime_info[k]
279279
r.set_value(
280-
torch_ops.OpRunValue(
280+
torch_ops.OpRunTensor(
281281
v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
282282
is_constant=False,
283283
may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
@@ -352,15 +352,15 @@ def run(
352352

353353
def run_with_values(
354354
self,
355-
*args: Optional[torch_ops.OpRunValue],
355+
*args: Optional[torch_ops.OpRunTensor],
356356
context: Optional[Dict[str, RuntimeValue]] = None,
357357
) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
358358
"""
359359
Runs the ONNX model.
360360
361361
:param args: inputs
362362
:param context: local context for the execution of subgraphs
363-
:return: output OpRunValue
363+
:return: output OpRunTensor
364364
"""
365365
assert all(
366366
isinstance(a, torch_ops.OpRunValue) for a in args
@@ -373,7 +373,7 @@ def run_with_values(
373373
r = self.runtime_info[k]
374374
if not r.has_value:
375375
r.set_value(
376-
torch_ops.OpRunValue(
376+
torch_ops.OpRunTensor(
377377
v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
378378
is_constant=True,
379379
may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
@@ -384,7 +384,7 @@ def run_with_values(
384384
for k, v in zip(self.input_names, args):
385385
r = self.runtime_info[k]
386386
r.set_value(
387-
torch_ops.OpRunValue(None) if v is None else v.__class__(v.tensor_or_sequence)
387+
torch_ops.OpRunTensor(None) if v is None else v.__class__(v.tensor_or_sequence)
388388
)
389389

390390
# node execution
@@ -406,7 +406,7 @@ def run_with_values(
406406
res = kernel.run(*inputs)
407407
if isinstance(res, tuple):
408408
# outputs
409-
assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
409+
assert all(isinstance(o, torch_ops.OpRunTensor) for o in res), (
410410
f"Unexpected output type {[type(o) for o in res]} "
411411
f"for kernel {type(kernel)}."
412412
)
@@ -425,7 +425,7 @@ def run_with_values(
425425
assert all(
426426
self.runtime_info[o].value is not None for o in outputs
427427
), "Not implemented yet when one output is None."
428-
res2 = [torch_ops.OpRunValue(self.runtime_info[o].value.tensor) for o in outputs] # type: ignore[assignment, union-attr]
428+
res2 = [self.runtime_info[o].value.copy() for o in outputs] # type: ignore[assignment, union-attr]
429429

430430
# clean previous execution
431431
for k in self.input_names:

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._op_run import OpRun, OpRunFunction, OpRunValue, OpRunValueSequence
1+
from ._op_run import OpRun, OpRunFunction, OpRunSequence, OpRunTensor, OpRunValue
22
from .access_ops import Gather_1, Slice_13
33
from .binary_ops import (
44
And_1,

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
from typing import Any, List, Optional, Union, Tuple
22
import onnx
33
import torch
4+
from ...api import TensorLike
45
from ...helpers import string_type
56
from ...helpers.torch_helper import to_tensor
67

78

8-
class OpRunValue:
9+
class OpRunValue(TensorLike):
10+
"""Defines a value for the runtime, a tensor or a sequence."""
11+
12+
__slots__ = ("cached", "is_constant", "sequence", "tensor")
13+
14+
@classmethod
15+
def is_sequence(cls) -> bool:
16+
"Tells if it is sequence."
17+
raise NotImplementedError("is_sequence must be overwritten.")
18+
19+
20+
class OpRunTensor(OpRunValue):
921
"""
1022
Wrapper around a tensor.
1123
@@ -15,10 +27,9 @@ class OpRunValue:
1527
more appropriate
1628
"""
1729

18-
__slots__ = ("cached", "is_constant", "tensor")
19-
2030
def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
2131
assert isinstance(tensor, torch.Tensor), f"Unexpected type {type(tensor)}"
32+
assert tensor is None or tensor.numel() > 1 or tensor.item() != -666666
2233
self.tensor = (
2334
tensor.cpu()
2435
if may_cpu
@@ -36,9 +47,9 @@ def is_sequence(cls) -> bool:
3647
"Tells if it is sequence."
3748
return False
3849

39-
def to(self, to: Any) -> "OpRunValue":
50+
def to(self, to: Any) -> "OpRunTensor":
4051
"Changes the device."
41-
return OpRunValue(self.tensor.to(to))
52+
return OpRunTensor(self.tensor.to(to))
4253

4354
def string_type(self) -> str:
4455
"Returns information about the value as a string."
@@ -96,17 +107,29 @@ def as_tuple_int(self) -> Tuple[int, ...]:
96107
return self.cached
97108
return self._tensor_as_tuple_int()
98109

110+
def copy(self) -> "OpRunTensor":
111+
"Shallow copy."
112+
return self.__class__(self.tensor)
99113

100-
class OpRunValueSequence(OpRunValue):
101-
"""Defines a sequence."""
102114

103-
__slots__ = ("cached", "is_constant", "sequence", "tensor")
115+
class OpRunSequence(OpRunValue):
116+
"""Defines a sequence."""
104117

105118
def __init__(
106119
self, sequence: Optional[List[torch.Tensor]] = None, dtype: torch.dtype = torch.float32
107120
):
108-
super().__init__(torch.empty((), dtype=dtype), False, False)
121+
self.tensor = torch.tensor(-666666, dtype=dtype)
122+
self.is_shape = False
109123
self.sequence = sequence or []
124+
self.cached: Optional[Tuple[int, ...]] = None
125+
assert all(
126+
isinstance(s, torch.Tensor) for s in self.sequence
127+
), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"
128+
129+
@property
130+
def dtype(self):
131+
"dtype (torch dtype)"
132+
return self.tensor.dtype
110133

111134
@property
112135
def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
@@ -119,18 +142,23 @@ def is_sequence(cls) -> bool:
119142
return True
120143

121144
def insert_at(
122-
self, tensor: OpRunValue, position: Optional[OpRunValue] = None
123-
) -> "OpRunValueSequence":
145+
self, tensor: torch.Tensor, position: Optional[OpRunTensor] = None
146+
) -> "OpRunSequence":
124147
"Inserts a value at a given position."
125-
new_seq = OpRunValueSequence()
148+
assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
149+
new_seq = OpRunSequence()
126150
seq = self.sequence.copy()
127151
new_seq.sequence = seq
128152
if position is None:
129-
seq.append(tensor)
153+
seq.append(tensor.tensor)
130154
else:
131-
seq.insert(int(position.tensor.item()), tensor)
155+
seq.insert(int(position.tensor.item()), tensor.tensor)
132156
return new_seq
133157

158+
def copy(self) -> "OpRunSequence":
159+
"Shallow copy."
160+
return self.__class__(self.sequence, dtype=self.dtype)
161+
134162

135163
class OpRun:
136164
"""

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
import onnx
33
import torch
4-
from . import OpRun, OpRunValue
4+
from . import OpRun, OpRunTensor
55

66

77
class Gather_1(OpRun):
@@ -18,20 +18,20 @@ def run(self, x, indices):
1818
return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device)
1919
ind = [slice(0, s) for s in x.shape]
2020
ind[self.axis] = indices.tensor
21-
return OpRunValue(x.tensor[tuple(ind)])
21+
return OpRunTensor(x.tensor[tuple(ind)])
2222

2323

2424
class Slice_13(OpRun):
2525
"Slice"
2626

2727
def run(
2828
self,
29-
data: OpRunValue,
30-
starts: OpRunValue,
31-
ends: OpRunValue,
32-
axes: Optional[OpRunValue] = None,
33-
steps: Optional[OpRunValue] = None,
34-
) -> OpRunValue:
29+
data: OpRunTensor,
30+
starts: OpRunTensor,
31+
ends: OpRunTensor,
32+
axes: Optional[OpRunTensor] = None,
33+
steps: Optional[OpRunTensor] = None,
34+
) -> OpRunTensor:
3535
if axes is None:
3636
if steps is None:
3737
slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
@@ -48,4 +48,4 @@ def run(
4848
slices = [slice(0, a) for a in data.shape]
4949
for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
5050
slices[a] = slice(s, e, d)
51-
return OpRunValue(data.tensor[tuple(slices)])
51+
return OpRunTensor(data.tensor[tuple(slices)])

0 commit comments

Comments
 (0)