
Commit 7053dc2

Preparation for a runtime with torch (#115)
* Preparation for a runtime with torch
* arg
* spelling
* to_str
* irpy
* doc
* black
1 parent d316ba6 commit 7053dc2

6 files changed, +269 -4 lines changed


_doc/conf.py

Lines changed: 2 additions & 1 deletion
@@ -119,7 +119,7 @@ def linkcode_resolve(domain, info):
     ("py:class", "Model"),
     ("py:class", "Module"),
     ("py:class", "np.ndarray"),
-    ("py:class", "onnxscript.ir.Tuple"),
+    ("py:class", "onnx_ir.Tuple"),
     ("py:class", "pipeline.Pipeline"),
     ("py:class", "torch.fx.passes.operator_support.OperatorSupport"),
     ("py:class", "torch.fx.proxy.TracerBase"),
@@ -216,6 +216,7 @@ def linkcode_resolve(domain, info):
     "monai": "https://monai.io/",
     "numpy": "https://numpy.org/",
     "onnx": "https://onnx.ai/onnx/",
+    "onnx-ir": "https://github.com/onnx/ir-py",
     "onnx.helper": "https://onnx.ai/onnx/api/helper.html",
     "ONNX": "https://onnx.ai/",
     "ONNX Operators": "https://onnx.ai/onnx/operators/",
New unit test file for onnx_diagnostic.torch_onnx.runtime_info

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_onnx.runtime_info import (
    first_used_last_used,
    RuntimeValue,
    RuntimeValueKind,
    RuntimeDevice,
)


class TestRuntimeInfo(ExtTestCase):
    def test_runtime_info(self):
        rt = RuntimeValue("e", is_shape=True, value=torch.Tensor([0]))
        r = repr(rt)
        self.assertEqual("RuntimeValue(name=e, is_shape=True, value=T1s1)", r)

    def test_runtime_kind(self):
        h = RuntimeValueKind.INPUT
        self.assertEqual(h.to_str(), "INPUT")

    def test_runtime_device(self):
        h = RuntimeDevice.CPU
        self.assertEqual(h.to_str(), "CPU")

    def test_runtime_values(self):
        def _mkv_(name):
            value_info_proto = onnx.ValueInfoProto()
            value_info_proto.name = name
            return value_info_proto

        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("ReduceSum", ["0X"], ["1Xred"]),
                    oh.make_node("Add", ["0X", "0two"], ["2X0"]),
                    oh.make_node("Add", ["2X0", "0zero"], ["3X00"]),
                    oh.make_node("CastLike", ["0one", "1Xred"], ["4one_c"]),
                    oh.make_node("Greater", ["1Xred", "4one_c"], ["5cond"]),
                    oh.make_node(
                        "If",
                        ["5cond"],
                        ["6Z_c"],
                        then_branch=oh.make_graph(
                            [
                                oh.make_node("Constant", [], ["0two"], value_floats=[2.1]),
                                oh.make_node("Add", ["3X00", "0two"], ["11Y"]),
                            ],
                            "then",
                            [],
                            [_mkv_("11Y")],
                        ),
                        else_branch=oh.make_graph(
                            [
                                oh.make_node("Constant", [], ["0two"], value_floats=[2.2]),
                                oh.make_node("Sub", ["2X0", "0two"], ["12Y"]),
                            ],
                            "else",
                            [],
                            [_mkv_("12Y")],
                        ),
                    ),
                    oh.make_node("CastLike", ["6Z_c", "0X"], ["7Z"]),
                ],
                "test",
                [
                    oh.make_tensor_value_info("0X", onnx.TensorProto.FLOAT, ["N"]),
                    oh.make_tensor_value_info("0one", onnx.TensorProto.FLOAT, ["N"]),
                ],
                [oh.make_tensor_value_info("7Z", onnx.TensorProto.UNDEFINED, ["N"])],
                [
                    onh.from_array(np.array([0], dtype=np.float32), name="0zero"),
                    onh.from_array(np.array([2], dtype=np.float32), name="0two"),
                ],
            ),
            opset_imports=[oh.make_operatorsetid("", 18)],
            ir_version=10,
        )
        rt_values = first_used_last_used(model)
        self.assertEqual(
            {
                "2X0",
                "0two",
                "5cond",
                "1Xred",
                "0zero",
                "0X",
                "4one_c",
                "7Z",
                "6Z_c",
                "0one",
                "3X00",
            },
            set(rt_values),
        )
        for name, node in rt_values.items():
            self.assertEqual(name, node.name)
            if name != "7Z":
                self.assertIsInstance(node.first_used, int)
                self.assertIsInstance(node.last_used, int)
            self.assertIsInstance(node.created, int, msg=f"{name!r} missing 'created'")
            self.assertIsInstance(node.kind, int)
            self.assertEqual(
                int(name[0]) - 1, node.created, msg=f"{name!r} created is wrong {node.created}"
            )
            if name != "7Z":
                self.assertGreater(node.first_used, node.created)
                self.assertGreaterOrEqual(node.last_used, node.first_used)


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_diagnostic/helpers/args_helper.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def get_parsed_args(
        )

    res = parser.parse_args(args=new_args)
-    update = {}
+    update: Dict[str, Union[int, float]] = {}
    for k, v in res.__dict__.items():
        try:
            vi = int(v)

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 2 additions & 2 deletions
@@ -537,7 +537,7 @@ def validate_model(

    disc = max_diff(data["expected"], expected)
    for k, v in disc.items():
-        summary[f"disc_patched_{k}"] = v
+        summary[f"disc_patched_{k}"] = str(v)
    if verbose:
        print("[validate_model] done (patched run)")
        print(f"[validate_model] patched discrepancies={string_diff(disc)}")
@@ -1066,7 +1066,7 @@ def call_torch_export_onnx(
        dynamo=False,
        dynamic_axes={
            k: v
-            for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)
+            for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)  # type: ignore[arg-type]
            .replace_by_string()
            .items()
            if isinstance(v, dict)
onnx_diagnostic/torch_onnx/runtime_info.py (new file)

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import enum
from typing import Any, Dict, Optional, Set, Tuple, Union
import onnx
import torch
from ..helpers import string_type


class RuntimeValueKind(enum.IntEnum):
    "Kind of result."

    INITIALIZER = 1
    INPUT = 2
    OUTPUT = 4
    RESULT = 12

    def to_str(self) -> str:
        for k, v in self.__class__.__dict__.items():
            if v == int(self):
                return k
        raise RuntimeError(f"Unable to display {self!r}")


class RuntimeDevice(enum.IntEnum):
    "Device definition"

    UNKNOWN = 0
    NEW = 1
    CPU = 2
    CUDA = 4

    def to_str(self) -> str:
        for k, v in self.__class__.__dict__.items():
            if v == int(self):
                return k
        raise RuntimeError(f"Unable to display {self!r}")


class RuntimeValue:
    """Describes a value used during the execution of a model."""

    def __init__(
        self,
        name: str,
        dtype: Optional[Any] = None,
        shape: Optional[Tuple[Union[str, int], ...]] = None,
        value: Optional[torch.Tensor] = None,
        first_used: Optional[int] = None,
        last_used: Optional[int] = None,
        created: Optional[int] = None,
        is_shape: Optional[bool] = None,
        kind: Optional[RuntimeValueKind] = None,
        device: Optional[RuntimeDevice] = None,
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.value = value
        self.first_used = first_used
        self.last_used = last_used
        self.created = created
        self.is_shape = is_shape
        self.kind = kind
        self.device = device

    def __repr__(self) -> str:
        "usual"
        ad = {}
        for att in [
            "name",
            "dtype",
            "shape",
            "first_used",
            "last_used",
            "is_shape",
            "kind",
            "created",
            "device",
        ]:
            v = getattr(self, att)
            if v is not None:
                ad[att] = v
        if self.value is not None:
            ad["value"] = string_type(self.value, with_shape=True)
        msg = ", ".join(
            f"{name}={t.to_str()}" if hasattr(t, "to_str") else f"{name}={t}"
            for name, t in ad.items()
        )
        return f"{self.__class__.__name__}({msg})"


def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph.
    """
    hidden = set()
    memo = set(i.name for i in graph.initializer)
    memo |= set(i.name for i in graph.sparse_initializer)
    for node in graph.node:
        for i in node.input:
            if i not in memo:
                hidden.add(i)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH and att.g:
                hid = get_hidden_inputs(att.g)
                less = set(h for h in hid if h not in memo)
                hidden |= less
        memo |= set(node.output)
    return hidden


def first_used_last_used(proto: onnx.ModelProto) -> Dict[str, RuntimeValue]:
    """
    Builds first used, last used information for every result
    in the model.

    :param proto: model
    :return: dictionary of RuntimeValue
    """
    values = {}
    for init in proto.graph.initializer:
        values[init.name] = RuntimeValue(
            init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
        )
    for init in proto.graph.sparse_initializer:
        values[init.name] = RuntimeValue(
            init.name, created=-1, kind=RuntimeValueKind.INITIALIZER
        )
    for inp in proto.graph.input:
        values[inp.name] = RuntimeValue(inp.name, created=-1, kind=RuntimeValueKind.INPUT)
    for it, node in enumerate(proto.graph.node):
        for i in node.input:
            if values[i].first_used is None:
                values[i].first_used = it
            values[i].last_used = it
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                for n in get_hidden_inputs(att.g):
                    if values[n].first_used is None:
                        values[n].first_used = it
                    values[n].last_used = it
        for o in node.output:
            values[o] = RuntimeValue(o, created=it, kind=RuntimeValueKind.RESULT)
    for out in proto.graph.output:
        values[out.name].kind = RuntimeValueKind.OUTPUT
        values[out.name].last_used = len(proto.graph.node)
    return values
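A rough usage sketch (not part of the commit, assuming the package is installed with this change): first_used_last_used returns one RuntimeValue per result, recording which node index created it and the first and last node indices that consume it; presumably this lifetime information is what a torch-based runtime would use to decide when a tensor can be released or moved between devices (see RuntimeDevice).

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx_diagnostic.torch_onnx.runtime_info import first_used_last_used

# A minimal two-node model: Y = Abs(X + bias).
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "bias"], ["T"]),
            oh.make_node("Abs", ["T"], ["Y"]),
        ],
        "demo",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["N"])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, ["N"])],
        [onh.from_array(np.array([1.0], dtype=np.float32), name="bias")],
    ),
    opset_imports=[oh.make_operatorsetid("", 18)],
    ir_version=10,
)

for name, rv in first_used_last_used(model).items():
    # Inputs and initializers are created at index -1, node outputs at their node
    # index; the graph output "Y" gets last_used == len(graph.node) and no first_used.
    print(name, rv.kind.to_str(), rv.created, rv.first_used, rv.last_used)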

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ huggingface_hub
 matplotlib
 onnx-array-api>=0.3.1
 onnx
+git+https://github.com/onnx/ir-py.git
 git+https://github.com/microsoft/onnxscript.git
 openpyxl
 packaging
