Commit 536c658

Commit message: doc
1 parent 31a9234 · commit 536c658

File tree: 8 files changed (+136 −31 lines)

_doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@
     "onnxrt backend": "https://pytorch.org/docs/stable/onnx_dynamo_onnxruntime_backend.html",
     "onnxruntime": "https://onnxruntime.ai/",
     "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
+    "onnxruntime kernels": "https://onnxruntime.ai/docs/reference/operators/OperatorKernels.html",
     "onnx-array-api": "https://sdpython.github.io/doc/onnx-array-api/dev/",
     "onnx-diagnostic": "https://sdpython.github.io/doc/onnx-diagnostic/dev/",
     "onnx-extended": "https://sdpython.github.io/doc/onnx-extended/dev/",

_doc/examples/plot_failing_onnxruntime_evaluator.py

Lines changed: 16 additions & 0 deletions
@@ -28,6 +28,7 @@
 import onnx.helper as oh
 import torch
 import onnxruntime
+from onnx_diagnostic.ext_test_case import has_cuda
 from onnx_diagnostic.helpers import from_array_extended
 from onnx_diagnostic.reference import OnnxruntimeEvaluator
@@ -81,6 +82,21 @@
 except Exception as e:
     print("ERROR", type(e), e)
 
+
+# %%
+# :epkg:`onnxruntime` may not support bfloat16 on CPU.
+# See :epkg:`onnxruntime kernels`.
+
+if has_cuda():
+    ref = OnnxruntimeEvaluator(model, providers="cuda", verbose=10)
+    feeds = dict(
+        X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
+    )
+    try:
+        ref.run(None, feeds)
+    except Exception as e:
+        print("ERROR", type(e), e)
+
 # %%
 # We can see it run until it reaches `Cast` and stops.
 # The error message is not always obvious to interpret.
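
Note: whether the CPU provider has a bfloat16 kernel for a given operator depends on the build; the :epkg:`onnxruntime kernels` page added to conf.py above lists the registered kernels. A quick local check, as a minimal sketch:

import onnxruntime

# Lists the execution providers compiled into this onnxruntime build;
# "CUDAExecutionProvider" must appear for the CUDA path in the example above to run.
print(onnxruntime.get_available_providers())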

_unittests/ut_reference/test_array_tensor.py

Lines changed: 1 addition & 2 deletions
@@ -2,9 +2,8 @@
 import numpy as np
 from onnx import TensorProto
 from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
-from onnx.reference.op_run import to_array_extended
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
-from onnx_diagnostic.helpers import from_array_extended
+from onnx_diagnostic.helpers import from_array_extended, to_array_extended
 from onnx_diagnostic.reference import ExtendedReferenceEvaluator
 
_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 86 additions & 1 deletion
@@ -1,6 +1,7 @@
 import unittest
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple
 import numpy as np
+import ml_dtypes
 from onnx import ModelProto, TensorProto
 from onnx.checker import check_model
 import onnx.helper as oh
@@ -12,6 +13,11 @@
     ignore_warnings,
     requires_cuda,
 )
+from onnx_diagnostic.helpers import (
+    from_array_extended,
+    onnx_dtype_to_torch_dtype,
+    onnx_dtype_to_np_dtype,
+)
 from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
 
 TFLOAT = TensorProto.FLOAT
@@ -163,6 +169,85 @@ def test_local_function(self):
         got = ort_eval.run(None, feeds)
         self.assertEqualArray(expected[0], got[0])
 
+    @classmethod
+    def _trange(cls, *shape, bias: Optional[float] = None):
+        n = np.prod(shape)
+        x = np.arange(n).astype(np.float32) / n
+        if bias:
+            x = x + bias
+        return torch.from_numpy(x.reshape(tuple(shape)).astype(np.float32))
+
+    @classmethod
+    def _get_model_init(cls, itype) -> Tuple[ModelProto, Dict[str, Any], Tuple[Any, ...]]:
+        dtype = onnx_dtype_to_np_dtype(itype)
+        ttype = onnx_dtype_to_torch_dtype(itype)
+        cst = np.arange(6).astype(dtype)
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("IsNaN", ["x"], ["xi"]),
+                    oh.make_node("IsNaN", ["y"], ["yi"]),
+                    oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64),
+                    oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64),
+                    oh.make_node("Add", ["xii", "yii"], ["gggg"]),
+                    oh.make_node("Cast", ["gggg"], ["final"], to=itype),
+                ],
+                "dummy",
+                [oh.make_tensor_value_info("x", itype, [None, None])],
+                [oh.make_tensor_value_info("final", itype, [None, None])],
+                [from_array_extended(cst, name="y")],
+            ),
+            opset_imports=[oh.make_opsetid("", 20)],
+            ir_version=10,
+        )
+        feeds = {"x": cls._trange(5, 6).to(ttype)}
+        expected = torch.isnan(feeds["x"]).to(int) + torch.isnan(
+            torch.from_numpy(cst.astype(float))
+        ).to(int)
+        return (model, feeds, (expected.to(ttype),))
+
+    @hide_stdout()
+    def test_init_numpy_afloat32(self):
+        model, feeds, expected = self._get_model_init(TensorProto.FLOAT)
+        wrap = OnnxruntimeEvaluator(
+            model, providers="cpu", graph_optimization_level=False, verbose=10
+        )
+        got = wrap.run(None, {k: v.numpy() for k, v in feeds.items()})
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertEqualArray(expected[0], got[0])
+
+    @hide_stdout()
+    def test_init_numpy_bfloat16(self):
+        model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16)
+        wrap = OnnxruntimeEvaluator(
+            model, providers="cpu", graph_optimization_level=False, verbose=10
+        )
+        got = wrap.run(
+            None, {k: v.to(float).numpy().astype(ml_dtypes.bfloat16) for k, v in feeds.items()}
+        )
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertEqualArray(expected[0], got[0])
+
+    @hide_stdout()
+    def test_init_torch_afloat32(self):
+        model, feeds, expected = self._get_model_init(TensorProto.FLOAT)
+        wrap = OnnxruntimeEvaluator(
+            model, providers="cpu", graph_optimization_level=False, verbose=10
+        )
+        got = wrap.run(None, feeds)
+        self.assertIsInstance(got[0], torch.Tensor)
+        self.assertEqualArray(expected[0], got[0])
+
+    @hide_stdout()
+    def test_init_torch_bfloat16(self):
+        model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16)
+        wrap = OnnxruntimeEvaluator(
+            model, providers="cpu", graph_optimization_level=False, verbose=10
+        )
+        got = wrap.run(None, feeds)
+        self.assertIsInstance(got[0], torch.Tensor)
+        self.assertEqualArray(expected[0], got[0])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
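
Side note on test_init_numpy_bfloat16: torch refuses to export a bfloat16 tensor through .numpy(), which is why the feed goes through float first and is narrowed back on the numpy side. A minimal sketch of that conversion (values are converted, not reinterpreted):

import ml_dtypes
import torch

t = torch.rand(3, 4, dtype=torch.bfloat16)
# .numpy() raises TypeError on bfloat16, so upcast to float first,
# then cast back down to ml_dtypes.bfloat16 on the numpy side.
a = t.to(float).numpy().astype(ml_dtypes.bfloat16)
assert a.dtype == ml_dtypes.bfloat16 and a.shape == (3, 4)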

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 7 additions & 2 deletions
@@ -23,6 +23,7 @@
     np_dtype_to_tensor_dtype,
     torch_dtype_to_onnx_dtype,
     from_array_extended,
+    to_array_extended,
     convert_endian,
     from_array_ml_dtypes,
     dtype_to_tensor_dtype,
@@ -250,16 +251,20 @@ def test_from_array(self):
             t = np.random.rand(4, 3).astype(dt)
             proto = from_array_extended(t)
             self.assertIsInstance(proto, onnx.TensorProto)
-            convert_endian(proto)
             dtype_to_tensor_dtype(dt)
+            arr = to_array_extended(proto)
+            self.assertEqualArray(t, arr)
+            convert_endian(proto)
 
     def test_from_array_ml_dtypes(self):
         for dt in {
             ml_dtypes.bfloat16,
         }:
             t = np.random.rand(4, 3).astype(dt)
-            from_array_ml_dtypes(t)
+            proto = from_array_ml_dtypes(t)
             from_array_extended(t)
+            arr = to_array_extended(proto)
+            self.assertEqualArray(t, arr)
 
     def test_size_type_mldtypes(self):
         for dt in {
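
The round-trip these tests now assert can be reproduced standalone; a minimal sketch, assuming onnx-diagnostic and ml_dtypes are installed:

import ml_dtypes
import numpy as np
from onnx_diagnostic.helpers import from_array_extended, to_array_extended

t = np.random.rand(4, 3).astype(ml_dtypes.bfloat16)
proto = from_array_extended(t)   # numpy (ml_dtypes) -> TensorProto
arr = to_array_extended(proto)   # TensorProto -> numpy, bfloat16 preserved
assert arr.dtype == t.dtype
np.testing.assert_array_equal(t.astype(np.float32), arr.astype(np.float32))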

onnx_diagnostic/helpers.py

Lines changed: 11 additions & 2 deletions
@@ -22,8 +22,7 @@
     np_dtype_to_tensor_dtype as onnx_np_dtype_to_tensor_dtype,
     tensor_dtype_to_np_dtype as onnx_tensor_dtype_to_np_dtype,
 )
-from onnx.numpy_helper import from_array as onnx_from_array
-from onnx.reference.op_run import to_array_extended
+from onnx.numpy_helper import from_array as onnx_from_array, to_array
 
 
 def size_type(dtype: Any) -> int:
@@ -845,6 +844,16 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
     return t
 
 
+def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
+    """Converts :class:`onnx.TensorProto` into a numpy array."""
+    arr = to_array(proto)
+    if proto.data_type >= onnx.TensorProto.BFLOAT16:
+        # Types not supported by numpy
+        ml_dtypes = onnx_dtype_to_np_dtype(proto.data_type)
+        return arr.view(ml_dtypes)
+    return arr
+
+
 def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
     """
     Converts an onnx type into a torch dtype.
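
The .view call at the heart of to_array_extended reinterprets bytes rather than converting values: for BFLOAT16 and the types above it, the same-width storage is re-labeled with the matching ml_dtypes dtype. A minimal illustration of that reinterpretation with plain numpy; the bit patterns in the comment are the standard bfloat16 encodings:

import ml_dtypes
import numpy as np

# 0x3F80, 0x4000, 0x4040 are the upper 16 bits of float32 1.0, 2.0, 3.0,
# i.e. their bfloat16 bit patterns; view() re-labels the buffer, no copy.
raw = np.array([0x3F80, 0x4000, 0x4040], dtype=np.uint16)
print(raw.view(ml_dtypes.bfloat16))  # [1 2 3]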

onnx_diagnostic/ort_session.py

Lines changed: 9 additions & 14 deletions
@@ -190,18 +190,9 @@ def run(
         self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
     ) -> List[npt.ArrayLike]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
-        if any(
-            (np_dtype_to_tensor_dtype(v.dtype) >= onnx.TensorProto.BFLOAT16)
-            for v in feeds.values()
-        ):
-            # bfloat16 not supported by onnxruntime
-            return self.run_dlpack(output_names, feeds)
-        if self.nvtx:
-            self.torch.cuda.nvtx.range_push("run")
-        res = self.sess.run(output_names, feeds)
-        if self.nvtx:
-            self.torch.cuda.nvtx.range_pop()
-        return res
+        # sess.run does not support bfloat16
+        # res = self.sess.run(output_names, feeds)
+        return self.run_dlpack(output_names, feeds)
 
     def run_dlpack(
         self, output_names: Optional[List[str]], feeds: Dict[str, np.ndarray]
@@ -213,8 +204,12 @@ def run_dlpack(
         """
         new_feeds = {}
         for k, v in feeds.items():
-            new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                v, np_dtype_to_tensor_dtype(v.dtype)
+            new_feeds[k] = (
+                ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+                    v, np_dtype_to_tensor_dtype(v.dtype)
+                )
+                if isinstance(v, np.ndarray)
+                else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
             )
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("run_with_ort_values")
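
The dlpack branch added here wraps a torch tensor's memory in an OrtValue without copying, which is what lets bfloat16 feeds through even though sess.run rejects them. A minimal sketch of that wrapping; the ORTC import path is an assumption (onnxruntime's C bindings, as the alias suggests in this module):

import torch
from onnxruntime.capi import _pybind_state as ORTC  # assumed ORTC alias

t = torch.rand(3, 4, dtype=torch.bfloat16)
# Wrap the tensor through the dlpack protocol; the second argument tells
# onnxruntime whether the buffer holds booleans.
ov = ORTC.OrtValue.from_dlpack(t.__dlpack__(), t.dtype == torch.bool)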

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 5 additions & 10 deletions
@@ -11,9 +11,8 @@
     load,
 )
 from onnx.defs import onnx_opset_version
-from onnx.numpy_helper import to_array
 import onnxruntime
-from ..helpers import pretty_onnx, dtype_to_tensor_dtype, string_type
+from ..helpers import pretty_onnx, dtype_to_tensor_dtype, string_type, to_array_extended
 from ..ort_session import InferenceSessionForTorch, InferenceSessionForNumpy, _InferenceSession
 
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
@@ -54,7 +53,7 @@ def __init__(
         log_verbosity_level: Optional[int] = None,
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
-        use_training_api: Optional[bool] = None,
+        use_training_api: bool = False,
         verbose: int = 0,
         local_functions: Optional[
             Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
@@ -103,7 +102,7 @@ def __init__(
             )
         )
         self.rt_inits_ = (
-            {init.name: to_array(init) for init in self.proto.graph.initializer}
+            {init.name: to_array_extended(init) for init in self.proto.graph.initializer}
            if hasattr(self.proto, "graph")
             else {}
         )
@@ -144,12 +143,8 @@ def output_names(self) -> List[str]:
     def _log_arg(self, a: Any) -> Any:
         if isinstance(a, (str, int, float)):
             return a
-        if hasattr(a, "detach"):
-            device = f"D{a.get_device()}:"
-            a = a.detach().cpu().numpy()
-        else:
-            device = ""
-        if isinstance(a, np.ndarray):
+        device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
+        if hasattr(a, "shape"):
             if self.verbose < 4:  # noqa: PLR2004
                 return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
             elements = a.ravel().tolist()
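
The rewritten _log_arg leans on duck typing: numpy arrays and torch tensors both expose dtype, shape, min() and max(), so the hasattr(a, "shape") branch serves both without the old detach/cpu/numpy round-trip. A small sketch with a hypothetical summarize helper (illustration only, not part of the module):

import numpy as np
import torch

def summarize(a):
    # Mirrors _log_arg's duck typing: one branch covers ndarray and Tensor.
    return f"{a.dtype}:{tuple(a.shape)} in [{a.min()}, {a.max()}]"

print(summarize(np.arange(6, dtype=np.float32).reshape(2, 3)))
print(summarize(torch.arange(6).to(torch.bfloat16).reshape(2, 3)))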
