Skip to content

Commit 11434ab

Browse files
committed
fix iterate
1 parent 73fba13 commit 11434ab

File tree

6 files changed

+188
-30
lines changed

6 files changed

+188
-30
lines changed

_unittests/ut_helpers/test_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@
2828
get_onnx_signature,
2929
type_info,
3030
onnx_dtype_name,
31-
onnx_dtype_to_torch_dtype,
3231
onnx_dtype_to_np_dtype,
3332
np_dtype_to_tensor_dtype,
34-
torch_dtype_to_onnx_dtype,
3533
from_array_extended,
3634
to_array_extended,
3735
convert_endian,
3836
from_array_ml_dtypes,
3937
dtype_to_tensor_dtype,
4038
)
39+
from onnx_diagnostic.helpers.torch_helper import (
40+
onnx_dtype_to_torch_dtype,
41+
torch_dtype_to_onnx_dtype,
42+
)
4143
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
4244
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
4345

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import unittest
2+
from typing import Any, Dict, List
23
import numpy as np
34
import onnx.helper as oh
45
import onnx.numpy_helper as onh
5-
from onnx import TensorProto
6+
from onnx import TensorProto, FunctionProto, ValueInfoProto
67
from onnx.checker import check_model
8+
import torch
79
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
810
from onnx_diagnostic.helpers.onnx_helper import (
911
onnx_lighten,
1012
onnx_unlighten,
1113
onnx_find,
1214
_validate_function,
1315
check_model_ort,
16+
iterator_initializer_constant,
17+
from_array_extended,
1418
)
1519

1620

@@ -122,6 +126,122 @@ def test_check_model_ort(self):
122126
)
123127
check_model_ort(model)
124128

129+
def test_iterate_init(self):
    # A single initializer "y" is stored in the main graph: the iterator
    # must return it once, as a numpy array or as a torch tensor.
    elem_type = TensorProto.FLOAT
    expected = np.arange(6).astype(np.float32)
    nodes = [
        oh.make_node("IsNaN", ["x"], ["xi"]),
        oh.make_node("IsNaN", ["y"], ["yi"]),
        oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64),
        oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64),
        oh.make_node("Add", ["xii", "yii"], ["gggg"]),
        oh.make_node("Cast", ["gggg"], ["final"], to=elem_type),
    ]
    graph = oh.make_graph(
        nodes,
        "dummy",
        [oh.make_tensor_value_info("x", elem_type, [None, None])],
        [oh.make_tensor_value_info("final", elem_type, [None, None])],
        [from_array_extended(expected, name="y")],
    )
    model = oh.make_model(
        graph,
        opset_imports=[oh.make_opsetid("", 20)],
        ir_version=10,
    )

    # numpy output (default)
    found = list(iterator_initializer_constant(model))
    self.assertEqual(len(found), 1)
    name, value = found[0]
    self.assertEqual(name, "y")
    self.assertEqualArray(value, expected)

    # torch output
    found = list(iterator_initializer_constant(model, use_numpy=False))
    self.assertEqual(len(found), 1)
    name, value = found[0]
    self.assertEqual(name, "y")
    self.assertEqualArray(value, expected)
    self.assertIsInstance(value, torch.Tensor)
159+
160+
def _get_cdist_implementation(
    self,
    node_inputs: List[str],
    node_outputs: List[str],
    opsets: Dict[str, int],
    **kwargs: Any,
) -> FunctionProto:
    """
    Returns the CDist implementation as a function.

    The function is named ``CDist_<metric>`` in domain ``npx``, takes
    inputs ``xa`` and ``xb`` and returns ``z``. It is made of a ``Scan``
    node followed by ``Sqrt`` (metric ``euclidean``) or ``Identity``
    (metric ``sqeuclidean``).

    :param node_inputs: names of the two inputs, only the length is checked
    :param node_outputs: name of the output, only the length is checked
    :param opsets: opsets, it must define the main domain ``""``
    :param kwargs: must contain ``metric`` and nothing else
    :return: the function
    """
    assert len(node_inputs) == 2
    assert len(node_outputs) == 1
    assert opsets
    assert "" in opsets
    assert set(kwargs) == {"metric"}, f"kwargs={kwargs}"
    metric = kwargs["metric"]
    assert metric in ("euclidean", "sqeuclidean")

    def untyped(name):
        # value info holding a name only, no type, no shape
        info = ValueInfoProto()
        info.name = name
        return info

    # body of the Scan node; the constant ``axis`` is defined inside it
    body_nodes = [
        oh.make_node("Sub", ["next", "next_in"], ["diff"]),
        oh.make_node("Constant", [], ["axis"], value_ints=[1]),
        oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
        oh.make_node("Identity", ["next_in"], ["next_out"]),
    ]
    body_inputs = [untyped(n) for n in ("next_in", "next")]
    body_outputs = [untyped(n) for n in ("next_out", "scan_out")]
    body = oh.make_graph(body_nodes, "loop", body_inputs, body_outputs)

    scan_node = oh.make_node(
        "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=body
    )
    if metric == "euclidean":
        last_node = oh.make_node("Sqrt", ["zout"], ["z"])
    else:
        last_node = oh.make_node("Identity", ["zout"], ["z"])

    return oh.make_function(
        "npx",
        f"CDist_{metric}",
        ["xa", "xb"],
        ["z"],
        [scan_node, last_node],
        [oh.make_opsetid("", opsets[""])],
    )
213+
214+
def test_iterate_function(self):
    # The only constant of this model is ``axis``, defined in the body of
    # the Scan node of the local function CDist_euclidean. The iterator must
    # walk through the function and the subgraph to find it.
    itype = TensorProto.FLOAT
    proto = self._get_cdist_implementation(
        ["X", "Y"], ["Z"], opsets={"": 18}, metric="euclidean"
    )
    model = oh.make_model(
        oh.make_graph(
            [
                # the domain binds the node to the local function
                oh.make_node(proto.name, ["X", "Y"], ["Z"], domain=proto.domain),
            ],
            "dummy",
            [
                oh.make_tensor_value_info("X", itype, [None, None]),
                oh.make_tensor_value_info("Y", itype, [None, None]),
            ],
            # the declared output is the one produced by the node
            [oh.make_tensor_value_info("Z", itype, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid(proto.domain, 1)],
        ir_version=10,
    )
    model.functions.append(proto)

    expected_name = "CDist_euclideanCDist_euclidean.axis"
    expected_value = np.array([1], dtype=np.int64)

    # numpy output (default)
    li = list(iterator_initializer_constant(model))
    self.assertEqual(len(li), 1)
    self.assertEqual(li[0][0], expected_name)
    self.assertEqualArray(li[0][1], expected_value)

    # torch output
    li = list(iterator_initializer_constant(model, use_numpy=False))
    self.assertEqual(len(li), 1)
    self.assertEqual(li[0][0], expected_name)
    self.assertEqualArray(li[0][1], expected_value)
    self.assertIsInstance(li[0][1], torch.Tensor)
244+
125245

126246
if __name__ == "__main__":
127247
unittest.main(verbosity=2)

_unittests/ut_helpers/test_ort_session.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
requires_onnxruntime_training,
1313
requires_cuda,
1414
)
15-
from onnx_diagnostic.helpers.onnx_helper import (
16-
from_array_extended,
17-
onnx_dtype_to_np_dtype,
18-
onnx_dtype_to_torch_dtype,
19-
)
15+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
16+
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
2017
from onnx_diagnostic.helpers.ort_session import (
2118
InferenceSessionForNumpy,
2219
InferenceSessionForTorch,

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@
1414
ignore_warnings,
1515
requires_cuda,
1616
)
17-
from onnx_diagnostic.helpers.onnx_helper import (
18-
from_array_extended,
19-
onnx_dtype_to_torch_dtype,
20-
onnx_dtype_to_np_dtype,
21-
)
17+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
18+
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
2219
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
2320
from onnx_diagnostic.helpers.ort_session import _InferenceSession
2421

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
import sys
5-
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
5+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
66
import numpy as np
77
import numpy.typing as npt
88
import onnx
@@ -762,3 +762,59 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
762762
return mapping[tensor_dtype]
763763

764764
return oh.tensor_dtype_to_np_dtype(tensor_dtype)
765+
766+
767+
def iterator_initializer_constant(
    model: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    use_numpy: bool = True,
    prefix: str = "",
) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
    """
    Iterates on initializers and constants in an onnx model.

    The function yields the initializers of a graph, the outputs of the
    nodes ``Constant`` (main domain) and recursively goes through
    the local functions of a model and the subgraphs of nodes
    ``If``, ``Loop``, ``Scan``.

    :param model: model, graph or function
    :param use_numpy: yields numpy arrays if True, torch tensors otherwise
    :param prefix: prefix added to the yielded names, used for
        functions and subgraphs
    :return: iterator on pairs ``(name, value)``

    Names found in a graph are ``<prefix>.<name>`` when *prefix* is not
    empty. A local function is visited with its name appended to the prefix,
    a subgraph with the name of the enclosing graph or function appended
    to the prefix.
    """
    if not use_numpy:
        # torch is only needed when tensors are requested
        import torch
        from .torch_helper import to_tensor

    if isinstance(model, onnx.FunctionProto):
        # a function has no initializer, only nodes
        nodes = model.node
        name = model.name
    else:
        graph = model if isinstance(model, onnx.GraphProto) else model.graph
        if prefix:
            prefix += "."
        for init in graph.initializer:
            yield f"{prefix}{init.name}", (
                to_array_extended(init) if use_numpy else to_tensor(init)
            )
        nodes = graph.node
        name = graph.name
        if isinstance(model, onnx.ModelProto):
            for f in model.functions:
                yield from iterator_initializer_constant(
                    f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
                )

    for node in nodes:
        if node.op_type == "Constant" and node.domain == "":
            # imported here, only when a constant is found
            from ..reference import ExtendedReferenceEvaluator as Inference

            # the evaluator handles every way of storing the constant
            # (value, value_ints, value_float, ...)
            sess = Inference(node)
            value = sess.run(None, {})[0]
            yield f"{prefix}{node.output[0]}", (
                value if use_numpy else torch.from_numpy(value)
            )

        # operators of the main domain holding subgraphs:
        # If (then_branch, else_branch), Loop (body), Scan (body)
        if node.op_type in {"If", "Loop", "Scan"}:
            for att in node.attribute:
                assert (
                    att.type != onnx.AttributeProto.GRAPHS
                ), "Not implemented for type AttributeProto.GRAPHS."
                if att.type == onnx.AttributeProto.GRAPH:
                    yield from iterator_initializer_constant(
                        att.g, use_numpy=use_numpy, prefix=f"{prefix}{name}"
                    )

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import warnings
77
from collections.abc import Iterable
8-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
8+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
99
import numpy as np
1010
import onnx
1111
from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
@@ -858,19 +858,5 @@ def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
858858
return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)
859859

860860
# Other cases, it should be small tensor. We use numpy.
861-
np_tensor = to_array_extended(tensor, base_dir)
861+
np_tensor = to_array_extended(tensor)
862862
return torch.from_numpy(np_tensor)
863-
864-
865-
def iterator_initializer_constant(
866-
model: onnx.ModelProto, use_numpy: bool = True
867-
) -> Iterator[Tuple[str, Union[torch.Tensor, np.ndarray]]]:
868-
"""
869-
Iterates on iniatialiers and constant in an onnx model.
870-
871-
:param model: model
872-
:param use_numpy: use numpy or pytorch
873-
:return: iterator
874-
"""
875-
for init in model.graph.initializer:
876-
yield init.name, (to_array_extended(init) if use_numpy else to_tensor(init))

0 commit comments

Comments
 (0)