
Commit 225a3ac
Commit message: final
1 parent: e8db2d8

File tree: 4 files changed, +73 additions, -9 deletions


_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@
 # Let's make sure the ONNX model produces the same outputs.
 # It takes flatten inputs.

-feeds = make_feeds(onx, copy.deepcopy(inputs), use_numpy=True)
+feeds = make_feeds(onx, copy.deepcopy(inputs), use_numpy=True, copy=True)

 print(f"torch inputs: {string_type(inputs)}")
 print(f"onxrt inputs: {string_type(feeds)}")

_unittests/ut_helpers/test_ort_session_tinyllm.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import copy
22
import unittest
3+
import numpy as np
4+
import onnx
35
import torch
46
import onnxruntime
7+
from onnxruntime.capi import _pybind_state as ORTC
58
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
69
from onnx_diagnostic.helpers import max_diff
710
from onnx_diagnostic.helpers.ort_session import (
@@ -12,10 +15,61 @@
1215
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
1316
from onnx_diagnostic.torch_models.llms import get_tiny_llm
1417
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
18+
from onnx_diagnostic.helpers.onnx_helper import np_dtype_to_tensor_dtype
1519

1620

1721
class TestOrtSessionTinyLLM(ExtTestCase):
1822

23+
def test_ort_value(self):
24+
val = np.array([30, 31, 32], dtype=np.int64)
25+
ort = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(val, onnx.TensorProto.INT64)
26+
self.assertEqual(np_dtype_to_tensor_dtype(val.dtype), onnx.TensorProto.INT64)
27+
val2 = ort.numpy()
28+
self.assertEqualArray(val, val2)
29+
ort = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
30+
val, np_dtype_to_tensor_dtype(val.dtype)
31+
)
32+
val2 = ort.numpy()
33+
self.assertEqualArray(val, val2)
34+
35+
def test_ort_value_py(self):
36+
data = get_tiny_llm()
37+
inputs = data["inputs"]
38+
feeds = make_feeds(
39+
["input_ids", "attention_mask", "position_ids", "key0", "value0"],
40+
inputs,
41+
use_numpy=True,
42+
copy=True,
43+
)
44+
new_feeds = {}
45+
for k, v in feeds.items():
46+
new_feeds[k] = onnxruntime.OrtValue.ortvalue_from_numpy_with_onnx_type(
47+
v, np_dtype_to_tensor_dtype(v.dtype)
48+
)
49+
other_feeds = {k: v.numpy() for k, v in new_feeds.items()}
50+
self.assertEqualAny(feeds, other_feeds)
51+
52+
def test_ort_value_more(self):
53+
data = get_tiny_llm()
54+
inputs = data["inputs"]
55+
feeds = make_feeds(
56+
["input_ids", "attention_mask", "position_ids", "key0", "value0"],
57+
inputs,
58+
use_numpy=True,
59+
copy=True,
60+
)
61+
feeds = {
62+
k: feeds[k].copy()
63+
for k in ["input_ids", "attention_mask", "key0", "value0", "position_ids"]
64+
}
65+
new_feeds = {}
66+
for k, v in feeds.items():
67+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
68+
v, np_dtype_to_tensor_dtype(v.dtype)
69+
)
70+
other_feeds = {k: v.numpy() for k, v in new_feeds.items()}
71+
self.assertEqualAny(feeds, other_feeds)
72+
1973
@ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
2074
@hide_stdout()
2175
def test_check_allruntimes_on_tiny_llm(self):
@@ -30,7 +84,7 @@ def test_check_allruntimes_on_tiny_llm(self):
3084

3185
proto = ep.model_proto
3286
self.dump_onnx("test_check_allruntimes_on_tiny_llm.onnx", proto)
33-
feeds = make_feeds(proto, inputs, use_numpy=True)
87+
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
3488
sess = onnxruntime.InferenceSession(
3589
proto.SerializeToString(), providers=["CPUExecutionProvider"]
3690
)
@@ -45,10 +99,10 @@ def test_check_allruntimes_on_tiny_llm(self):
4599
self.assertEqualArray(got[0], all_outputs["linear_7"])
46100

47101
sess = InferenceSessionForNumpy(proto)
48-
got = sess.run(None, feeds, expected=all_outputs)
102+
got = sess.run(None, feeds)
49103
self.assertLess(max_diff(expected, got, flatten=True)["abs"], 1e-5)
50104

51-
feeds = make_feeds(proto, inputs)
105+
feeds = make_feeds(proto, inputs, copy=True)
52106
sess = InferenceSessionForTorch(proto)
53107
got = sess.run(None, feeds)
54108
self.assertLess(max_diff(expected, got, flatten=True)["abs"], 1e-5)
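
A condensed, hedged version of the round trip these new tests exercise, assuming an onnxruntime build that exposes OrtValue.ortvalue_from_numpy_with_onnx_type (the same call the tests above rely on): build an OrtValue from a numpy array with an explicit ONNX element type, then read it back.

import numpy as np
import onnx
import onnxruntime

val = np.array([30, 31, 32], dtype=np.int64)
# The ONNX element type is given explicitly instead of being inferred
# from the numpy dtype.
ov = onnxruntime.OrtValue.ortvalue_from_numpy_with_onnx_type(val, onnx.TensorProto.INT64)
assert np.array_equal(ov.numpy(), val)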

onnx_diagnostic/ext_test_case.py

Lines changed: 2 additions & 2 deletions
@@ -1066,7 +1066,7 @@ def assert_onnx_disc(
         if verbose:
             print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
         if use_ort:
-            feeds = make_feeds(proto, inputs, use_numpy=True)
+            feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
             if verbose:
                 print(f"[{vname}] feeds {string_type(feeds, **kws)}")
             import onnxruntime
@@ -1076,7 +1076,7 @@ def assert_onnx_disc(
             )
             got = sess.run(None, feeds)
         else:
-            feeds = make_feeds(proto, inputs)
+            feeds = make_feeds(proto, inputs, copy=True)
             if verbose:
                 print(f"[{vname}] feeds {string_type(feeds, **kws)}")
             sess = InferenceSessionForTorch(proto, **kwargs)

onnx_diagnostic/helpers/ort_session.py

Lines changed: 13 additions & 3 deletions
@@ -19,15 +19,20 @@


 def make_feeds(
-    proto: onnx.ModelProto, inputs: Any, use_numpy: bool = False
+    proto: Union[onnx.ModelProto, List[str]],
+    inputs: Any,
+    use_numpy: bool = False,
+    copy: bool = False,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
     by :class:`onnxruntime.InferenceSession`.

-    :param proto: onnx model
+    :param proto: onnx model or list of names
     :param inputs: any kind of inputs
     :param use_numpy: if True, converts torch tensors into numpy arrays
+    :param copy: a copy is made, this should be the case if the inputs is ingested
+        by ``OrtValue``
     :return: feeds dictionary
     """
     flat = flatten_object(inputs, drop_keys=True)
@@ -42,7 +47,11 @@ def make_feeds(
     )
     if use_numpy:
         flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
-    names = [i.name for i in proto.graph.input]
+    names = (
+        [i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto
+    )
+    if copy:
+        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
     return dict(zip(names, flat))


@@ -242,6 +251,7 @@ def run_dlpack(
             if isinstance(v, np.ndarray)
             else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
         )
+
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("run_with_ort_values")
         ort_outputs = self.sess._sess.run_with_ort_values(
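
For reference, a standalone sketch of the dispatch the updated make_feeds now performs on its first argument (resolve_input_names is a hypothetical helper written for illustration, not part of the library): a ModelProto contributes its graph input names, while a plain list of names is used as-is.

from typing import List, Union
import onnx

def resolve_input_names(proto: Union[onnx.ModelProto, List[str]]) -> List[str]:
    # Mirrors the new branch in make_feeds: read names from the graph
    # inputs of a ModelProto, otherwise trust the caller-provided list.
    if isinstance(proto, onnx.ModelProto):
        return [i.name for i in proto.graph.input]
    return list(proto)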
