Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion _unittests/ut_helpers/test_torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
replace_string_by_dynamic,
to_any,
torch_deepcopy,
torch_tensor_size,
)
from onnx_diagnostic.helpers.cache_helper import (
make_dynamic_cache,
Expand Down Expand Up @@ -204,7 +205,7 @@ def forward(self, x, y):
else:
print("output", k, v)
print(string_type(restored, with_shape=True))
l1, l2 = 182, 191
l1, l2 = 183, 192
self.assertEqual(
[
(f"-Model-{l2}", 0, "I"),
Expand Down Expand Up @@ -264,6 +265,7 @@ def test_torch_deepcopy_cache_dce(self):
c1.key_cache[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cc), 1)

def test_torch_deepcopy_mamba_cache(self):
cache = make_mamba_cache(
Expand All @@ -280,6 +282,7 @@ def test_torch_deepcopy_mamba_cache(self):
cache.conv_states[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cache), 1)

def test_torch_deepcopy_base_model_outputs(self):
bo = transformers.modeling_outputs.BaseModelOutput(
Expand All @@ -292,6 +295,7 @@ def test_torch_deepcopy_base_model_outputs(self):
bo.last_hidden_state[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(bo), 1)

def test_torch_deepcopy_sliding_windon_cache(self):
cache = make_sliding_window_cache(
Expand All @@ -308,9 +312,11 @@ def test_torch_deepcopy_sliding_windon_cache(self):
cache.key_cache[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cache), 1)

def test_torch_deepcopy_none(self):
    """Checks how ``torch_deepcopy`` and ``torch_tensor_size`` handle ``None``."""
    copied = torch_deepcopy(None)
    self.assertEmpty(copied)

    # None is expected to account for a size of zero.
    size = torch_tensor_size(None)
    self.assertEqual(size, 0)

def test_model_statistics(self):
class Model(torch.nn.Module):
Expand Down
94 changes: 94 additions & 0 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import torch
import onnxruntime
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator

try:
Expand Down Expand Up @@ -96,6 +99,97 @@ def false_fn(x, y):
for e, g in zip(expected, got):
self.assertEqualArray(e, g, atol=1e-5)

def test_constant_bool(self):
    """Checks a scalar boolean ``Constant`` node returns a boolean with both evaluators."""
    bool_dtypes = (torch.bool, np.bool_)
    scalar_true = from_array_extended(np.array(True, dtype=np.bool_))
    node = oh.make_node("Constant", [], ["cbool"], value=scalar_true)

    # Python reference evaluator.
    py_result = ExtendedReferenceEvaluator(node).run(None, {})[0]
    self.assertEqual(py_result.dtype, np.bool_)
    self.assertEqual(py_result, True)

    # onnxruntime based evaluator: exactly one entry must be cached.
    ort_eval = OnnxruntimeEvaluator(node, opsets=21)
    ort_result = ort_eval.run(None, {})[0]
    self.assertEqual(len(ort_eval._cache), 1)

    # Every cache value is a pair, its second item can be run directly.
    _, cached_session = next(iter(ort_eval._cache.values()))
    direct_result = cached_session.run(None, {})[0]
    self.assertIn(direct_result.dtype, bool_dtypes)
    self.assertEqual(direct_result, True)

    self.assertIn(ort_result.dtype, bool_dtypes)
    self.assertEqual(ort_result, True)

def test_constant_bool_array(self):
    """Checks a one-element boolean ``Constant`` node returns booleans with both evaluators."""
    bool_dtypes = (torch.bool, np.bool_)
    array_true = from_array_extended(np.array([True], dtype=np.bool_))
    node = oh.make_node("Constant", [], ["cbool"], value=array_true)

    # Python reference evaluator.
    py_result = ExtendedReferenceEvaluator(node).run(None, {})[0]
    self.assertEqual(py_result.dtype, np.bool_)
    self.assertEqual(py_result[0], True)

    # onnxruntime based evaluator: exactly one entry must be cached.
    ort_eval = OnnxruntimeEvaluator(node, opsets=21)
    ort_result = ort_eval.run(None, {})[0]
    self.assertEqual(len(ort_eval._cache), 1)

    # Every cache value is a pair, its second item can be run directly.
    _, cached_session = next(iter(ort_eval._cache.values()))
    direct_result = cached_session.run(None, {})[0]
    self.assertIn(direct_result.dtype, bool_dtypes)
    self.assertEqual(direct_result[0], True)

    self.assertIn(ort_result.dtype, bool_dtypes)
    self.assertEqual(ort_result[0], True)

def test_constant_bool_input(self):
    """
    Runs a boolean ``Identity`` model with both evaluators.

    The model copies the boolean input ``bin`` into the output ``bout``.
    It is evaluated with a numpy feed by both evaluators, then with
    a torch feed by :class:`OnnxruntimeEvaluator`, and the result
    must stay boolean every time.
    """
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Identity", ["bin"], ["bout"])],
            "test",
            [oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
            # The declared graph output must be the result produced by
            # Identity ("bout"), not the graph input itself.
            [oh.make_tensor_value_info("bout", onnx.TensorProto.BOOL, [1])],
        ),
        ir_version=10,
        opset_imports=[oh.make_opsetid("", 18)],
    )
    feeds = dict(bin=np.array([True], dtype=np.bool_))

    # Python reference evaluator with a numpy input.
    ref = ExtendedReferenceEvaluator(model)
    got = ref.run(None, feeds)[0]
    self.assertEqual(got.dtype, np.bool_)
    self.assertEqual(got[0], True)

    # onnxruntime based evaluator with a numpy input.
    ref = OnnxruntimeEvaluator(model, opsets=21)
    got = ref.run(None, feeds)[0]
    self.assertEqual(got.dtype, np.bool_)
    self.assertEqual(got[0], True)

    # Same evaluator with a torch input: the output is expected to be a torch boolean.
    feeds = dict(bin=torch.tensor([True], dtype=torch.bool))
    got = ref.run(None, feeds)[0]
    self.assertEqual(got.dtype, torch.bool)
    self.assertEqual(got[0], True)

def test_ort_eval_loop(self):
    """Exports an ``EmbeddingBag`` model and compares onnxruntime results with torch."""
    bag = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum")
    indices = torch.tensor([[39906, 39906]]).long()
    inputs = (indices,)
    bag_eval = bag.eval()
    expected = bag(*inputs)

    onx = to_onnx(bag_eval, inputs, optimize=True)
    # The exported graph must contain a Loop node.
    op_types = {n.op_type for n in onx.graph.node}
    self.assertIn("Loop", op_types)

    evaluator = OnnxruntimeEvaluator(onx, verbose=10)
    input_names = [i.name for i in onx.graph.input]
    arrays = [t.detach().numpy() for t in inputs]
    feeds = dict(zip(input_names, arrays))
    got = evaluator.run(None, feeds)
    self.assertEqualArray(expected, got[0])


if __name__ == "__main__":
unittest.main(verbosity=2)
46 changes: 36 additions & 10 deletions onnx_diagnostic/helpers/ort_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
)
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
if isinstance(sess, onnx.ModelProto):
debug_path = "_debug_onnxruntine_evaluator_failure.onnx"
debug_path = "_debug_InferenceSession_last_failure.onnx"
onnx.save(
sess,
debug_path,
Expand Down Expand Up @@ -154,6 +154,7 @@ def __init__(
)

self._torch_from_dlpack = _from_dlpack
self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]

def run(
self,
Expand All @@ -166,7 +167,19 @@ def run(
ort_outputs = self.sess._sess.run_with_ort_values(
feeds, output_names or self.output_names, self.run_options
)
return ort_outputs
return self._post_process_inplace(ort_outputs)

def _post_process_inplace(self, outputs):
    """
    Casts to boolean, in place, the outputs flagged in ``self.sess_bool_outputs``.

    :param outputs: mutable list of results, ``outputs[i]`` is matched
        with ``self.sess_bool_outputs[i]``
    :return: the same list object, where a flagged :class:`numpy.ndarray`
        or :class:`torch.Tensor` which is not boolean yet is replaced
        by its boolean cast

    Values which are neither a numpy array nor a torch tensor
    are left untouched.
    """
    # NOTE(review): outputs are matched by position, this assumes they follow
    # the order of the session outputs -- confirm when only a subset is requested.
    for i, value in enumerate(outputs):
        if not self.sess_bool_outputs[i]:
            continue
        if isinstance(value, np.ndarray):
            if value.dtype != np.bool_:
                outputs[i] = value.astype(np.bool_)
        elif isinstance(value, torch.Tensor):
            if value.dtype != torch.bool:
                outputs[i] = value.to(torch.bool)
    return outputs


class InferenceSessionForNumpy(_InferenceSession):
Expand Down Expand Up @@ -221,7 +234,7 @@ def run(
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
# sess.run does not support bfloat16
# res = self.sess.run(output_names, feeds)
return list(self.run_dlpack(output_names, feeds))
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

def run_dlpack(
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
Expand All @@ -231,17 +244,23 @@ def run_dlpack(
feeds is a dictionary of :class:`np.ndarray`.
The output device is CPU even if the outputs are on CUDA.
"""
memory = []
new_feeds = {}
for k, v in feeds.items():
if not k:
continue
new_feeds[k] = (
ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
if isinstance(v, np.ndarray):
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
v, np_dtype_to_tensor_dtype(v.dtype)
)
if isinstance(v, np.ndarray)
else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
)
elif v.dtype == torch.bool:
vi = v.detach().cpu().numpy()
memory.append(vi)
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
vi, onnx.TensorProto.BOOL
)
else:
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)

if self.nvtx:
self.torch.cuda.nvtx.range_push("run_with_ort_values")
Expand Down Expand Up @@ -421,7 +440,7 @@ def run( # type: ignore
if self.use_training_api:
inputs = [feeds[i] for i in self.input_names]
return self.run_training_api(*inputs, output_names=output_names)
return self.run_dlpack(output_names, feeds)
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

def run_training_api(
self, *inputs, output_names: Optional[List[str]] = None
Expand Down Expand Up @@ -471,7 +490,14 @@ def run_dlpack(
assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
if not v.is_contiguous():
v = v.contiguous()
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
if v.dtype == torch.bool:
# torch.bool tensors do not work with dlpack
# unless onnxruntime updates the dlpack version it is using.
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
v.detach().numpy(), onnx.TensorProto.BOOL
)
else:
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
if self.nvtx:
self.torch.cuda.nvtx.range_push("run_with_ort_values")
ort_outputs = self.sess._sess.run_with_ort_values(
Expand Down
Loading
Loading