Commit 0e9b867
changes
1 parent b4f4d54 commit 0e9b867

6 files changed: +136 additions, -10 deletions

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 45 additions & 5 deletions

@@ -44,6 +44,7 @@ class TorchOnnxEvaluator:
     :param providers: where to run the model
     :param opsets: needed if proto is a graph
     :param functions: known local functions
+    :param verbose: verbosity level
 
     The class holds the following attributes:
 
@@ -56,6 +57,7 @@ class TorchOnnxEvaluator:
     * `last_used`: contains the list of intermediate results,
       to remove after every node execution,
       this avoid the memory to grow too much
+    * `functions`: local functions
 
     The class is not multithreaded. `runtime_info` gets updated
     by the the class. The list of available kernels is returned by function
@@ -68,12 +70,15 @@ def __init__(
         providers: Tuple[str, ...] = ("CPUExecutionProvider",),
         opsets: Optional[Dict[str, int]] = None,
         local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
+        verbose: int = 0,
     ):
+        assert verbose
        self.providers = providers
        self.constants: Dict[str, torch.Tensor] = {}
        self.kernels: List[Optional[torch_ops.OpRun]] = []
        self.functions = local_functions.copy() if local_functions else {}
        self.CPU = torch.tensor([0]).to("cpu").device
+        self.verbose = verbose
        if "CUDAExecutionProvider" in providers:
            self.CUDA = torch.tensor([0]).to("cuda").device
            self.default_device = self.CUDA
@@ -87,8 +92,11 @@ def __init__(
        assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
        self.opsets = {d.domain: d.version for d in proto.opset_import}
        for f in proto.functions:
-            self.functions[f.domain, f.name] = TorchOnnxEvaluator(
-                f, providers=providers, local_functions=self.functions
+            self.functions[f.domain, f.name] = self.__class__(
+                f,
+                providers=providers,
+                local_functions=self.functions,
+                verbose=self.verbose,
            )
        self._build_initializers(proto.graph.initializer)
        self._build_initializers(proto.graph.node)
@@ -206,22 +214,36 @@ def run(
            if not r.has_value:
                r.set_value(
                    torch_ops.OpRunValue(
-                        v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
+                        v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
+                        is_constant=True,
+                        may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                    )
                )
+                if self.verbose:
+                    print(f"+C {r.name}: {r.string_type()}")
 
        # inputs
        for k, v in feeds.items():
            r = self.runtime_info[k]
            r.set_value(
                torch_ops.OpRunValue(
-                    v.to(self.CUDA) if r.is_shape and self.on_cuda else v, False
+                    v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
+                    is_constant=False,
+                    may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                )
            )
+            if self.verbose:
+                print(f"+I {r.name}: {r.string_type()}")
 
        # node execution
        for it, kernel in enumerate(self.kernels):
            if kernel is not None:
+                if self.verbose:
+                    print(
+                        f"{kernel.__class__.__name__}"
+                        f"({', '.join(kernel.input)}) -> "
+                        f"{', '.join(kernel.output)}"
+                    )
                # kernel execution
                inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
                if kernel.has_subgraphs():
@@ -236,26 +258,42 @@ def run(
                )
                for name, t in zip(kernel.output, res):
                    self.runtime_info[name].set_value(t)
+                if self.verbose:
+                    for name in kernel.output:
+                        print(f"+R {name}: {self.runtime_info[name].string_type()}")
            else:
                assert isinstance(
                    res, torch_ops.OpRunValue
                ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                self.runtime_info[kernel.output[0]].set_value(res)
+                if self.verbose:
+                    print(
+                        f"+R {kernel.output[0]}: "
+                        f"{self.runtime_info[kernel.output[0]].string_type()}"
+                    )
 
            # free intermediate results
            for name in self.last_used[it]:
                self.runtime_info[name].clean_value()
+                if self.verbose:
+                    print(f"- clean {name}")
 
        assert all(
            self.runtime_info[o].value is not None for o in outputs
        ), "Not implemented yet when one output is None."
        fres = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[union-attr]
+        if self.verbose:
+            print(f"++ outputs {', '.join(outputs)}")
 
        # clean previous execution
        for k in feeds:
            self.runtime_info[k].clean_value()
+            if self.verbose:
+                print(f"- clean {k}")
        for o in outputs:
            self.runtime_info[o].clean_value()
+            if self.verbose:
+                print(f"- clean {o}")
 
        if use_numpy:
            return [None if a is None else a.detach().cpu().numpy() for a in fres]
@@ -285,7 +323,9 @@ def run_with_values(
            if not r.has_value:
                r.set_value(
                    torch_ops.OpRunValue(
-                        v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
+                        v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
+                        is_constant=True,
+                        may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                    )
                )
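
For context only, a minimal usage sketch of the new verbose flag; it is not part of the commit. The import path is assumed from the file shown above, the tiny Identity model is built inline so the snippet stays self-contained, and the run() signature is not visible in this diff, so the call is only indicated in a comment. With verbose=1, run() prints the "+C", "+I", "+R" and "- clean" lines added in the hunks above.

import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_evaluator import TorchOnnxEvaluator  # path assumed from the file above

# Build a trivial Identity model so the sketch is self-contained.
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Identity", ["X"], ["Y"])],
        "g",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# verbose=1 enables the trace added in this commit.
sess = TorchOnnxEvaluator(model, providers=("CPUExecutionProvider",), verbose=1)

# The run() signature is not shown in this hunk; if it mirrors
# onnxruntime.InferenceSession.run, the call would look like:
# result = sess.run(None, {"X": torch.tensor([1.0, 2.0])})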

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 32 additions & 2 deletions

@@ -11,15 +11,32 @@ class OpRunValue:
     :param tensor: torch.Tensor
     :param is_constant: is it a constant
+    :param may_cpu: change the device the tensor is if
+        more appropriate
     """
 
     __slots__ = ("cached", "is_constant", "tensor")
 
-    def __init__(self, tensor, is_constant: bool = False):
-        self.tensor = tensor
+    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
+        self.tensor = (
+            tensor.cpu()
+            if may_cpu
+            and len(tensor.shape) == 1
+            and tensor.numel() < 8
+            and tensor.dtype == torch.int64
+            and tensor.get_device() >= 0
+            else tensor
+        )
        self.is_constant = is_constant
        self.cached: Optional[Tuple[int, ...]] = None
 
+    def string_type(self) -> str:
+        "Returns informations about the value as a string."
+        s = string_type(self.tensor, with_shape=True, with_min_max=True, with_device=True)
+        if self.is_constant:
+            return f"CST({s})"
+        return s
+
    def __repr__(self) -> str:
        "usual"
        if self.is_constant:
@@ -42,6 +59,19 @@ def dtype(self):
    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
        return tuple(map(int, self.tensor))
 
+    def numel(self) -> int:
+        "Returns the number of elements."
+        return 0 if self.tensor is None else self.tensor.numel()
+
+    def get_device(self) -> int:
+        "Returns the device id."
+        return -1 if self.tensor is None else self.tensor.get_device()
+
+    @property
+    def device(self):
+        "Returns the device."
+        return -1 if self.tensor is None else self.tensor.device
+
    @property
    def as_tuple_int(self) -> Tuple[int, ...]:
        "value as int"

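A small self-contained sketch of the may_cpu heuristic introduced above; it is illustrative rather than the repository's code. Short 1-D int64 tensors usually hold shapes or indices, so they are pulled back to CPU when they currently sit on a GPU.

import torch

def maybe_to_cpu(t: torch.Tensor) -> torch.Tensor:
    # get_device() returns -1 for CPU tensors and the device index (>= 0) for CUDA tensors.
    small_int64_vector = len(t.shape) == 1 and t.numel() < 8 and t.dtype == torch.int64
    return t.cpu() if small_int64_vector and t.get_device() >= 0 else t

shape = torch.tensor([1, 3, 224, 224], dtype=torch.int64)
print(maybe_to_cpu(shape).device)  # already on CPU: returned unchanged
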
onnx_diagnostic/reference/torch_ops/control_flow.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ def __init__(
                 providers=parent.providers,
                 opsets=parent.opsets,
                 local_functions=parent.functions,
+                verbose=parent.verbose,
             )
             setattr(self, att.name, rt)

Lines changed: 25 additions & 1 deletion

@@ -1,11 +1,35 @@
+from typing import Optional
+import onnx
 import torch
 from . import OpRun, OpRunValue
 
 
 class Range_11(OpRun):
     """Range"""
 
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ):
+        super().__init__(node, version)
+        self.device = device
+
     def run(self, starts: OpRunValue, limit: OpRunValue, delta: OpRunValue) -> OpRunValue:
         return OpRunValue(
-            torch.arange(starts.tensor, limit.tensor, delta.tensor, dtype=starts.dtype)
+            torch.arange(
+                starts.tensor,
+                limit.tensor,
+                delta.tensor,
+                dtype=starts.dtype,
+                device=self.device,
+            )
         )
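
As a side note, a minimal plain-PyTorch sketch (not the kernel above) of why the added device= argument matters for Range: the output is allocated directly on the target device instead of being created on CPU and copied afterwards.

import torch

# Pick CUDA when available, otherwise fall back to CPU so the sketch still runs.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The range is allocated on `device` directly; no host tensor plus .to(device) copy.
r = torch.arange(0, 12, 3, dtype=torch.int64, device=device)
print(r.device, r)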

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 5 additions & 1 deletion

@@ -968,7 +968,11 @@ def _mk(key):
                 )
             )
             if runtime == "onnxruntime"
-            else (lambda model, providers: TorchOnnxEvaluator(model, providers=providers))
+            else (
+                lambda model, providers: TorchOnnxEvaluator(
+                    model, providers=providers, verbose=max(verbose - 1, 0)
+                )
+            )
         )
         sess = _quiet_or_not_quiet(
             quiet,
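
A tiny sketch, with illustrative names only, of the verbosity cascade used above: the helper keeps one level for itself and hands the evaluator max(verbose - 1, 0), so the per-node trace only appears once the caller is already verbose.

def child_verbosity(verbose: int) -> int:
    # One level is "consumed" by the caller; never go below zero.
    return max(verbose - 1, 0)

assert child_verbosity(0) == 0   # quiet caller -> quiet evaluator
assert child_verbosity(1) == 0   # caller logs, evaluator stays quiet
assert child_verbosity(2) == 1   # evaluator starts tracing nodes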

onnx_diagnostic/torch_onnx/runtime_info.py

Lines changed: 28 additions & 1 deletion

@@ -80,7 +80,11 @@ def __repr__(self) -> str:
             if v is not None:
                 ad[att] = v
         if self.value is not None:
-            ad["value"] = string_type(self.value, with_shape=True)
+            ad["value"] = (
+                self.value.string_type()
+                if hasattr(self.value, "string_type")
+                else string_type(self.value, with_shape=True)
+            )
        msg = ", ".join(
            f"{name}={t.to_str()}" if hasattr(t, "to_str") else f"{name}={t}"
            for name, t in ad.items()
@@ -92,9 +96,32 @@ def has_value(self) -> bool:
        "Tells if value is specified."
        return self.value is not None
 
+    def string_type(self) -> str:
+        "Returns a string describing the value."
+        rows = []
+        if self.shape is not None:
+            rows.append(f"shape={self.shape}")
+        if self.is_shape is not None:
+            rows.append(f"is_shape={self.is_shape}")
+        if self.device is not None:
+            rows.append(f"device={self.device}")
+        text = f", {', '.join(rows)}" if rows else ""
+        if self.value is None:
+            return (
+                f"RuntimeValue(name={self.name!r}{text}"
+                f", dtype={self.dtype}, kind={self.kind})"
+            )
+        return (
+            f"RuntimeValue(name={self.name!r}, "
+            f"kind={self.kind}{text}, value={self.value.string_type()})"
+        )
+
    def set_value(self, value: torch.Tensor):
        """Sets the value."""
        assert value is not None, "Use clean_value to set a value to None"
+        assert (
+            self.name != "position_ids" or value.get_device() >= 0
+        ), f"{value} - is_shape={self.is_shape}"
        self.value = value
        if self.dtype:
            assert (
