Skip to content

Commit 7936215

Browse files
authored
Add to_tensor (#98)
* Add to_tensor * fix iterate * fix import * fix issues * spell * tiny modif * fix device
1 parent e151311 commit 7936215

26 files changed

+699
-214
lines changed

_doc/api/helpers/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ onnx_diagnostic.helpers
1616
onnx_helper
1717
ort_session
1818
rt_helper
19-
torch_test_helper
19+
torch_helper
2020

2121
.. autofunction:: onnx_diagnostic.helpers.flatten_object
2222

_doc/api/helpers/torch_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.torch_helper
3+
====================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.torch_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/torch_test_helper.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_doc/examples/plot_export_tiny_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import transformers
3232
from onnx_diagnostic import doc
3333
from onnx_diagnostic.helpers import string_type
34-
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
34+
from onnx_diagnostic.helpers.torch_helper import steal_forward
3535
from onnx_diagnostic.torch_models.llms import get_tiny_llm
3636

3737

@@ -77,7 +77,7 @@ def _forward_(*args, _f=None, **kwargs):
7777
model.forward = keep_model_forward
7878

7979
# %%
80-
# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`.
80+
# Another syntax with :func:`onnx_diagnostic.helpers.torch_helper.steal_forward`.
8181

8282
with steal_forward(model):
8383
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)

_unittests/ut_export/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
requires_onnxscript,
99
)
1010
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
11-
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
11+
from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting
1212

1313
try:
1414
from experimental_experiment.torch_interpreter import to_onnx

_unittests/ut_helpers/test_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@
2828
get_onnx_signature,
2929
type_info,
3030
onnx_dtype_name,
31-
onnx_dtype_to_torch_dtype,
3231
onnx_dtype_to_np_dtype,
3332
np_dtype_to_tensor_dtype,
34-
torch_dtype_to_onnx_dtype,
3533
from_array_extended,
3634
to_array_extended,
3735
convert_endian,
3836
from_array_ml_dtypes,
3937
dtype_to_tensor_dtype,
4038
)
39+
from onnx_diagnostic.helpers.torch_helper import (
40+
onnx_dtype_to_torch_dtype,
41+
torch_dtype_to_onnx_dtype,
42+
)
4143
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
4244
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
4345

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
import unittest
2+
from typing import Any, Dict, List
23
import numpy as np
34
import onnx.helper as oh
45
import onnx.numpy_helper as onh
5-
from onnx import TensorProto
6+
from onnx import TensorProto, FunctionProto, ValueInfoProto
67
from onnx.checker import check_model
8+
import torch
79
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
810
from onnx_diagnostic.helpers.onnx_helper import (
911
onnx_lighten,
1012
onnx_unlighten,
1113
onnx_find,
1214
_validate_function,
1315
check_model_ort,
16+
iterator_initializer_constant,
17+
from_array_extended,
18+
tensor_statistics,
1419
)
1520

1621

1722
TFLOAT = TensorProto.FLOAT
1823

1924

20-
class TestOnnxTools(ExtTestCase):
25+
class TestOnnxHelper(ExtTestCase):
2126

2227
def _get_model(self):
2328
model = oh.make_model(
@@ -122,6 +127,130 @@ def test_check_model_ort(self):
122127
)
123128
check_model_ort(model)
124129

130+
def test_iterate_init(self):
131+
itype = TensorProto.FLOAT
132+
cst = np.arange(6).astype(np.float32)
133+
model = oh.make_model(
134+
oh.make_graph(
135+
[
136+
oh.make_node("IsNaN", ["x"], ["xi"]),
137+
oh.make_node("IsNaN", ["y"], ["yi"]),
138+
oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64),
139+
oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64),
140+
oh.make_node("Add", ["xii", "yii"], ["gggg"]),
141+
oh.make_node("Cast", ["gggg"], ["final"], to=itype),
142+
],
143+
"dummy",
144+
[oh.make_tensor_value_info("x", itype, [None, None])],
145+
[oh.make_tensor_value_info("final", itype, [None, None])],
146+
[from_array_extended(cst, name="y")],
147+
),
148+
opset_imports=[oh.make_opsetid("", 20)],
149+
ir_version=10,
150+
)
151+
li = list(iterator_initializer_constant(model))
152+
self.assertEqual(len(li), 1)
153+
self.assertEqual(li[0][0], "y")
154+
self.assertEqualArray(li[0][1], cst)
155+
li = list(iterator_initializer_constant(model, use_numpy=False))
156+
self.assertEqual(len(li), 1)
157+
self.assertEqual(li[0][0], "y")
158+
self.assertEqualArray(li[0][1], cst)
159+
self.assertIsInstance(li[0][1], torch.Tensor)
160+
161+
def _get_cdist_implementation(
162+
self,
163+
node_inputs: List[str],
164+
node_outputs: List[str],
165+
opsets: Dict[str, int],
166+
**kwargs: Any,
167+
) -> FunctionProto:
168+
"""
169+
Returns the CDist implementation as a function.
170+
"""
171+
assert len(node_inputs) == 2
172+
assert len(node_outputs) == 1
173+
assert opsets
174+
assert "" in opsets
175+
assert set(kwargs) == {"metric"}, f"kwargs={kwargs}"
176+
metric = kwargs["metric"]
177+
assert metric in ("euclidean", "sqeuclidean")
178+
# subgraph
179+
nodes = [
180+
oh.make_node("Sub", ["next", "next_in"], ["diff"]),
181+
oh.make_node("Constant", [], ["axis"], value_ints=[1]),
182+
oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
183+
oh.make_node("Identity", ["next_in"], ["next_out"]),
184+
]
185+
186+
def make_value(name):
187+
value = ValueInfoProto()
188+
value.name = name
189+
return value
190+
191+
graph = oh.make_graph(
192+
nodes,
193+
"loop",
194+
[make_value("next_in"), make_value("next")],
195+
[make_value("next_out"), make_value("scan_out")],
196+
)
197+
198+
scan = oh.make_node(
199+
"Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph
200+
)
201+
final = (
202+
oh.make_node("Sqrt", ["zout"], ["z"])
203+
if metric == "euclidean"
204+
else oh.make_node("Identity", ["zout"], ["z"])
205+
)
206+
return oh.make_function(
207+
"npx",
208+
f"CDist_{metric}",
209+
["xa", "xb"],
210+
["z"],
211+
[scan, final],
212+
[oh.make_opsetid("", opsets[""])],
213+
)
214+
215+
def test_iterate_function(self):
216+
itype = TensorProto.FLOAT
217+
proto = self._get_cdist_implementation(
218+
["X", "Y"], ["Z"], opsets={"": 18}, metric="euclidean"
219+
)
220+
model = oh.make_model(
221+
oh.make_graph(
222+
[
223+
oh.make_node(proto.name, ["X", "Y"], ["Z"]),
224+
],
225+
"dummy",
226+
[
227+
oh.make_tensor_value_info("X", itype, [None, None]),
228+
oh.make_tensor_value_info("Y", itype, [None, None]),
229+
],
230+
[oh.make_tensor_value_info("final", itype, [None, None])],
231+
),
232+
opset_imports=[oh.make_opsetid("", 18)],
233+
ir_version=10,
234+
)
235+
model.functions.append(proto)
236+
li = list(iterator_initializer_constant(model))
237+
self.assertEqual(len(li), 1)
238+
self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis")
239+
self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64))
240+
li = list(iterator_initializer_constant(model, use_numpy=False))
241+
self.assertEqual(len(li), 1)
242+
self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis")
243+
self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64))
244+
self.assertIsInstance(li[0][1], torch.Tensor)
245+
246+
def test_statistics(self):
247+
rnd = np.random.rand(40, 50).astype(np.float16)
248+
stat = tensor_statistics(rnd)
249+
self.assertEqual(stat["stype"], "FLOAT16")
250+
rnd = np.random.rand(40, 50).astype(np.float32)
251+
stat = tensor_statistics(rnd)
252+
self.assertEqual(stat["stype"], "FLOAT")
253+
125254

126255
if __name__ == "__main__":
127256
unittest.main(verbosity=2)

_unittests/ut_helpers/test_ort_session.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
requires_onnxruntime_training,
1313
requires_cuda,
1414
)
15-
from onnx_diagnostic.helpers.onnx_helper import (
16-
from_array_extended,
17-
onnx_dtype_to_np_dtype,
18-
onnx_dtype_to_torch_dtype,
19-
)
15+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
16+
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
2017
from onnx_diagnostic.helpers.ort_session import (
2118
InferenceSessionForNumpy,
2219
InferenceSessionForTorch,

_unittests/ut_helpers/test_torch_test_helper.py renamed to _unittests/ut_helpers/test_torch_helper.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import unittest
2+
import numpy as np
23
import ml_dtypes
34
import onnx
45
import torch
56
import transformers
67
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
78
from onnx_diagnostic.helpers import max_diff, string_type
8-
from onnx_diagnostic.helpers.torch_test_helper import (
9+
from onnx_diagnostic.helpers.torch_helper import (
910
dummy_llm,
1011
to_numpy,
1112
is_torchdynamo_exporting,
@@ -24,6 +25,8 @@
2425
make_sliding_window_cache,
2526
)
2627
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
28+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
29+
from onnx_diagnostic.helpers.torch_helper import to_tensor
2730

2831
TFLOAT = onnx.TensorProto.FLOAT
2932

@@ -205,7 +208,7 @@ def forward(self, x, y):
205208
else:
206209
print("output", k, v)
207210
print(string_type(restored, with_shape=True))
208-
l1, l2 = 183, 192
211+
l1, l2 = 186, 195
209212
self.assertEqual(
210213
[
211214
(f"-Model-{l2}", 0, "I"),
@@ -344,6 +347,35 @@ def forward(self, x, y=None):
344347
stat,
345348
)
346349

350+
def test_to_tensor(self):
351+
for dtype in [
352+
np.int8,
353+
np.uint8,
354+
np.int16,
355+
np.uint16,
356+
np.int32,
357+
np.uint32,
358+
np.int64,
359+
np.uint64,
360+
np.float16,
361+
np.float32,
362+
np.float64,
363+
]:
364+
with self.subTest(dtype=dtype):
365+
a = np.random.rand(4, 5).astype(dtype)
366+
proto = from_array_extended(a)
367+
b = to_array_extended(proto)
368+
self.assertEqualArray(a, b)
369+
c = to_tensor(proto)
370+
self.assertEqualArray(a, c)
371+
372+
for dtype in [torch.bfloat16]:
373+
with self.subTest(dtype=dtype):
374+
a = torch.rand((4, 5), dtype=dtype)
375+
proto = from_array_extended(a)
376+
c = to_tensor(proto)
377+
self.assertEqualArray(a, c)
378+
347379

348380
if __name__ == "__main__":
349381
unittest.main(verbosity=2)

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@
1414
ignore_warnings,
1515
requires_cuda,
1616
)
17-
from onnx_diagnostic.helpers.onnx_helper import (
18-
from_array_extended,
19-
onnx_dtype_to_torch_dtype,
20-
onnx_dtype_to_np_dtype,
21-
)
17+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
18+
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
2219
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
2320
from onnx_diagnostic.helpers.ort_session import _InferenceSession
2421

0 commit comments

Comments
 (0)