
Commit 31a9234

fix bfloat16
1 parent: d365f3d

File tree

  _doc/examples/plot_failing_onnxruntime_evaluator.py
  _unittests/ut_xrun_doc/test_helpers.py
  _unittests/ut_xrun_doc/test_ort_session.py
  onnx_diagnostic/helpers.py
  onnx_diagnostic/ort_session.py

5 files changed: +201 -32 lines changed

_doc/examples/plot_failing_onnxruntime_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@
 ref = OnnxruntimeEvaluator(model, verbose=10)
 feeds = dict(
-    X=torch.rand((3, 4), dtype=torch.blofat16), Y=torch.rand((3, 4), dtype=torch.blofat16)
+    X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
 )
 try:
     ref.run(None, feeds)

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     string_signature,
     make_hash,
     onnx_dtype_to_torch_dtype,
+    onnx_dtype_to_np_dtype,
     np_dtype_to_tensor_dtype,
     torch_dtype_to_onnx_dtype,
     from_array_extended,
@@ -213,6 +214,7 @@ def test_size_type_onnx(self):
             "FLOAT8E4M3FNUZ",
         }:
             onnx_dtype_to_torch_dtype(i)
+            onnx_dtype_to_np_dtype(i)
 
     def test_size_type_numpy(self):
         for dt in {

_unittests/ut_xrun_doc/test_ort_session.py

Lines changed: 67 additions & 1 deletion
@@ -1,6 +1,7 @@
 import unittest
-from typing import Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 import numpy as np
+import ml_dtypes
 import onnx
 import onnx.helper as oh
 import torch
@@ -11,6 +12,11 @@
     requires_onnxruntime_training,
     requires_cuda,
 )
+from onnx_diagnostic.helpers import (
+    from_array_extended,
+    onnx_dtype_to_np_dtype,
+    onnx_dtype_to_torch_dtype,
+)
 from onnx_diagnostic.ort_session import (
     InferenceSessionForNumpy,
     InferenceSessionForTorch,
@@ -232,6 +238,66 @@ def test_investigate_onnxruntime_issue_callable_str(self):
             onnx_to_session="cpu_session",
         )
 
+    @classmethod
+    def _get_model_init(cls, itype) -> Tuple[onnx.ModelProto, Dict[str, Any], Tuple[Any, ...]]:
+        dtype = onnx_dtype_to_np_dtype(itype)
+        ttype = onnx_dtype_to_torch_dtype(itype)
+        cst = np.arange(6).astype(dtype)
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("IsNaN", ["x"], ["xi"]),
+                    oh.make_node("IsNaN", ["y"], ["yi"]),
+                    oh.make_node("Cast", ["xi"], ["xii"], to=onnx.TensorProto.INT64),
+                    oh.make_node("Cast", ["yi"], ["yii"], to=onnx.TensorProto.INT64),
+                    oh.make_node("Add", ["xii", "yii"], ["gggg"]),
+                    oh.make_node("Cast", ["gggg"], ["final"], to=itype),
+                ],
+                "dummy",
+                [oh.make_tensor_value_info("x", itype, [None, None])],
+                [oh.make_tensor_value_info("final", itype, [None, None])],
+                [from_array_extended(cst, name="y")],
+            ),
+            opset_imports=[oh.make_opsetid("", 20)],
+            ir_version=10,
+        )
+        onnx.checker.check_model(model)
+        feeds = {"x": cls._range(5, 6).to(ttype)}
+        expected = torch.isnan(feeds["x"]).to(int) + torch.isnan(
+            torch.from_numpy(cst.astype(float))
+        ).to(int)
+        return (model, feeds, (expected.to(ttype),))
+
+    def test_init_numpy_afloat32(self):
+        model, feeds, expected = self._get_model_init(onnx.TensorProto.FLOAT)
+        wrap = InferenceSessionForNumpy(model, providers="cpu", graph_optimization_level=False)
+        got = wrap.run(None, {k: v.numpy() for k, v in feeds.items()})
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertEqualArray(expected[0], got[0])
+
+    def test_init_numpy_bfloat16(self):
+        model, feeds, expected = self._get_model_init(onnx.TensorProto.BFLOAT16)
+        wrap = InferenceSessionForNumpy(model, providers="cpu", graph_optimization_level=False)
+        got = wrap.run(
+            None, {k: v.to(float).numpy().astype(ml_dtypes.bfloat16) for k, v in feeds.items()}
+        )
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertEqualArray(expected[0], got[0])
+
+    def test_init_torch_afloat32(self):
+        model, feeds, expected = self._get_model_init(onnx.TensorProto.FLOAT)
+        wrap = InferenceSessionForTorch(model, providers="cpu", graph_optimization_level=False)
+        got = wrap.run(None, feeds)
+        self.assertIsInstance(got[0], torch.Tensor)
+        self.assertEqualArray(expected[0], got[0])
+
+    def test_init_torch_bfloat16(self):
+        model, feeds, expected = self._get_model_init(onnx.TensorProto.BFLOAT16)
+        wrap = InferenceSessionForTorch(model, providers="cpu", graph_optimization_level=False)
+        got = wrap.run(None, feeds)
+        self.assertIsInstance(got[0], torch.Tensor)
+        self.assertEqualArray(expected[0], got[0])
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
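Note: numpy has no native bfloat16 dtype, which is why test_init_numpy_bfloat16 routes its torch feeds through float before casting with ml_dtypes. A minimal standalone sketch of that conversion, assuming ml_dtypes is installed:

import ml_dtypes
import torch

# torch handles bfloat16 natively; numpy does not, so the round trip
# goes through float64 and then the ml_dtypes extension dtype.
t = torch.rand((3, 4), dtype=torch.bfloat16)
a = t.to(float).numpy().astype(ml_dtypes.bfloat16)
assert a.shape == (3, 4) and a.dtype == ml_dtypes.bfloat16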

onnx_diagnostic/helpers.py

Lines changed: 45 additions & 0 deletions
@@ -889,6 +889,51 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
     )
 
 
+def onnx_dtype_to_np_dtype(itype: int) -> Any:
+    """
+    Converts an onnx element type into a numpy dtype.
+    That includes :epkg:`ml_dtypes` dtypes.
+
+    :param itype: onnx dtype
+    :return: numpy dtype
+    """
+    if itype == TensorProto.FLOAT:
+        return np.float32
+    if itype == TensorProto.FLOAT16:
+        return np.float16
+    if itype == TensorProto.BFLOAT16:
+        import ml_dtypes
+
+        return ml_dtypes.bfloat16
+    if itype == TensorProto.DOUBLE:
+        return np.float64
+    if itype == TensorProto.INT32:
+        return np.int32
+    if itype == TensorProto.INT64:
+        return np.int64
+    if itype == TensorProto.UINT32:
+        return np.uint32
+    if itype == TensorProto.UINT64:
+        return np.uint64
+    if itype == TensorProto.BOOL:
+        return np.bool_
+    if itype == TensorProto.INT16:
+        return np.int16
+    if itype == TensorProto.UINT16:
+        return np.uint16
+    if itype == TensorProto.INT8:
+        return np.int8
+    if itype == TensorProto.UINT8:
+        return np.uint8
+    if itype == TensorProto.COMPLEX64:
+        return np.complex64
+    if itype == TensorProto.COMPLEX128:
+        return np.complex128
+    raise NotImplementedError(
+        f"Unable to convert onnx type {onnx_dtype_name(itype)} to a numpy dtype."
+    )
+
+
 def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
     """
     Converts a torch dtype into an onnx element type.
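A quick usage sketch for the new helper; the mapping follows the branches added in this commit, and the printed values are what those branches return:

import onnx
from onnx_diagnostic.helpers import onnx_dtype_to_np_dtype

# standard element types map to plain numpy dtypes
print(onnx_dtype_to_np_dtype(onnx.TensorProto.FLOAT))     # <class 'numpy.float32'>
# bfloat16 maps to the ml_dtypes extension dtype
print(onnx_dtype_to_np_dtype(onnx.TensorProto.BFLOAT16))  # <class 'ml_dtypes.bfloat16'>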

onnx_diagnostic/ort_session.py

Lines changed: 86 additions & 30 deletions
@@ -6,6 +6,13 @@
 from torch._C import _from_dlpack
 import onnxruntime
 from onnxruntime.capi import _pybind_state as ORTC
+from .helpers import (
+    torch_dtype_to_onnx_dtype,
+    onnx_dtype_to_np_dtype,
+    np_dtype_to_tensor_dtype,
+    onnx_dtype_name,
+    size_type,
+)
 
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
 
@@ -48,7 +55,14 @@ def __init__(
     ):
         # onnxruntime is imported when needed as the import takes
        # a couple of seconds when the CUDA EP is included.
+        can_use_training_api = True
         if isinstance(sess, (onnx.ModelProto, str)):
+            if isinstance(sess, onnx.ModelProto):
+                for i in sess.graph.initializer:
+                    if i.data_type >= onnx.TensorProto.BFLOAT16:
+                        # Cannot use the training API as it relies too much on numpy.
+                        can_use_training_api = False
+                        break
             assert session_options is None or (
                 providers is None
                 and graph_optimization_level is None
@@ -113,7 +127,7 @@ def __init__(
         if log_verbosity_level is not None:
             self.run_options.log_verbosity_level = log_verbosity_level
 
-        self.use_training_api = (
+        self.use_training_api = can_use_training_api and (
             self.has_onnxruntime_training() if use_training_api is None else use_training_api
         )
 
@@ -176,7 +190,76 @@ def run(
     def run(
         self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
     ) -> List[npt.ArrayLike]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
-        return self.sess.run(output_names, feeds)
+        if any(
+            (np_dtype_to_tensor_dtype(v.dtype) >= onnx.TensorProto.BFLOAT16)
+            for v in feeds.values()
+        ):
+            # bfloat16 (and the types above it) are not supported by
+            # onnxruntime when feeds are numpy arrays, go through OrtValue.
+            return self.run_dlpack(output_names, feeds)
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_push("run")
+        res = self.sess.run(output_names, feeds)
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_pop()
+        return res
+
+    def run_dlpack(
+        self, output_names: Optional[List[str]], feeds: Dict[str, np.ndarray]
+    ) -> Tuple[np.ndarray, ...]:
+        """
+        Same as :meth:`onnxruntime.InferenceSession.run` except that
+        feeds is a dictionary of :class:`np.ndarray`.
+        Outputs are returned on CPU even if they were computed on CUDA.
+        """
+        new_feeds = {}
+        for k, v in feeds.items():
+            new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+                v, np_dtype_to_tensor_dtype(v.dtype)
+            )
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_push("run_with_ort_values")
+        ort_outputs = self.sess._sess.run_with_ort_values(
+            new_feeds, output_names or self.output_names, self.run_options
+        )
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_pop()
+        return self._ortvalues_to_numpy_tensor(ort_outputs)
+
+    def _ortvalues_to_numpy_tensor(
+        self,
+        ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
+    ) -> Tuple[np.ndarray, ...]:
+        if len(ortvalues) == 0:
+            return tuple()
+
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
+        res = []
+        for i in range(len(ortvalues)):
+            if not ortvalues[i].has_value():
+                res.append(None)
+                continue
+
+            el_type = ortvalues[i].element_type()
+            if el_type < onnx.TensorProto.BFLOAT16:
+                res.append(np.from_dlpack(ortvalues[i]))
+                continue
+
+            # no easy conversion for bfloat16, go through torch: view the
+            # raw 2-byte storage as uint16, export it to numpy, then view
+            # it as the matching ml_dtypes dtype.
+            tch = torch.from_dlpack(ortvalues[i].to_dlpack())
+            size = size_type(el_type)
+            assert size == 2, f"Not implemented for type {onnx_dtype_name(el_type)}"
+            u16 = tch.view(torch.uint16).numpy()
+            res.append(u16.view(onnx_dtype_to_np_dtype(el_type)))
+
+        if self.nvtx:
+            self.torch.cuda.nvtx.range_pop()
+        return tuple(res)
 
 
 class InferenceSessionForTorch(_InferenceSession):
@@ -225,33 +308,6 @@ def __init__(
             use_training_api=use_training_api,
         )
 
-        self.TORCH_DTYPE_TO_ONNX_DTYPE = {
-            torch.float16: onnx.TensorProto.FLOAT16,
-            torch.bfloat16: onnx.TensorProto.BFLOAT16,
-            torch.float32: onnx.TensorProto.FLOAT,
-            torch.float64: onnx.TensorProto.DOUBLE,
-            torch.uint32: onnx.TensorProto.UINT32,
-            torch.uint16: onnx.TensorProto.UINT16,
-            torch.uint8: onnx.TensorProto.UINT8,
-            torch.int8: onnx.TensorProto.INT8,
-            torch.int16: onnx.TensorProto.INT16,
-            torch.int32: onnx.TensorProto.INT32,
-            torch.int64: onnx.TensorProto.INT64,
-            torch.bool: onnx.TensorProto.BOOL,
-        }
-
-        self.TORCH_DTYPE_TO_NUMPY_DTYPE = {
-            torch.float16: np.float16,
-            torch.float32: np.float32,
-            torch.float64: np.float64,
-            torch.uint8: np.uint8,
-            torch.int8: np.int8,
-            torch.int16: np.int16,
-            torch.int32: np.int32,
-            torch.int64: np.int64,
-            torch.bool: np.bool_,
-        }
-
     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
     ) -> Tuple[ORTC.OrtValueVector, List[onnxruntime.OrtDevice]]:
@@ -269,7 +325,7 @@ def _get_ortvalues_from_torch_tensors(
         new_tensors = []
         for tensor in tensors:
             assert isinstance(tensor, self.torch.Tensor), f"Unexpected type {type(tensor)}"
-            dtypes.append(self.TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
+            dtypes.append(onnx_dtype_to_np_dtype(torch_dtype_to_onnx_dtype(tensor.dtype)))
             shapes.append(tensor.size())
             data_ptrs.append(tensor.data_ptr())
             d = tensor.get_device()
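The key trick in _ortvalues_to_numpy_tensor is reinterpreting the raw two-byte storage rather than converting values. A standalone sketch of the same round trip, assuming torch >= 2.3 (for torch.uint16) and ml_dtypes installed:

import ml_dtypes
import numpy as np
import torch

# bfloat16 tensors cannot be exported to numpy directly: view the raw
# 2-byte storage as uint16, export it, then view it as ml_dtypes.bfloat16.
t = torch.tensor([1.5, -2.0, 3.25], dtype=torch.bfloat16)
a = t.view(torch.uint16).numpy().view(ml_dtypes.bfloat16)
assert np.allclose(a.astype(np.float32), t.to(torch.float32).numpy())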
