Skip to content

Commit e151311

Browse files
authored
Add verbosity to steal_forward (#97)
* Add verbosity
* better onnxruntimeeval
* fix bug in onnxruntime
* fix bool inputs
* name
* mypy
* add garbe
* fix issues
* fix issues
1 parent 08b1cdf commit e151311

File tree

5 files changed

+340
-38
lines changed

5 files changed

+340
-38
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
replace_string_by_dynamic,
1616
to_any,
1717
torch_deepcopy,
18+
torch_tensor_size,
1819
)
1920
from onnx_diagnostic.helpers.cache_helper import (
2021
make_dynamic_cache,
@@ -204,7 +205,7 @@ def forward(self, x, y):
204205
else:
205206
print("output", k, v)
206207
print(string_type(restored, with_shape=True))
207-
l1, l2 = 182, 191
208+
l1, l2 = 183, 192
208209
self.assertEqual(
209210
[
210211
(f"-Model-{l2}", 0, "I"),
@@ -264,6 +265,7 @@ def test_torch_deepcopy_cache_dce(self):
264265
c1.key_cache[0] += 1000
265266
hash2 = string_type(at, with_shape=True, with_min_max=True)
266267
self.assertEqual(hash1, hash2)
268+
self.assertGreater(torch_tensor_size(cc), 1)
267269

268270
def test_torch_deepcopy_mamba_cache(self):
269271
cache = make_mamba_cache(
@@ -280,6 +282,7 @@ def test_torch_deepcopy_mamba_cache(self):
280282
cache.conv_states[0] += 1000
281283
hash2 = string_type(at, with_shape=True, with_min_max=True)
282284
self.assertEqual(hash1, hash2)
285+
self.assertGreater(torch_tensor_size(cache), 1)
283286

284287
def test_torch_deepcopy_base_model_outputs(self):
285288
bo = transformers.modeling_outputs.BaseModelOutput(
@@ -292,6 +295,7 @@ def test_torch_deepcopy_base_model_outputs(self):
292295
bo.last_hidden_state[0] += 1000
293296
hash2 = string_type(at, with_shape=True, with_min_max=True)
294297
self.assertEqual(hash1, hash2)
298+
self.assertGreater(torch_tensor_size(bo), 1)
295299

296300
def test_torch_deepcopy_sliding_windon_cache(self):
297301
cache = make_sliding_window_cache(
@@ -308,9 +312,11 @@ def test_torch_deepcopy_sliding_windon_cache(self):
308312
cache.key_cache[0] += 1000
309313
hash2 = string_type(at, with_shape=True, with_min_max=True)
310314
self.assertEqual(hash1, hash2)
315+
self.assertGreater(torch_tensor_size(cache), 1)
311316

312317
def test_torch_deepcopy_none(self):
313318
self.assertEmpty(torch_deepcopy(None))
319+
self.assertEqual(torch_tensor_size(None), 0)
314320

315321
def test_model_statistics(self):
316322
class Model(torch.nn.Module):

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import unittest
2+
import numpy as np
23
import onnx
4+
import onnx.helper as oh
35
import torch
46
import onnxruntime
57
from onnx_diagnostic.ext_test_case import ExtTestCase
8+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
69
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
710

811
try:
@@ -96,6 +99,97 @@ def false_fn(x, y):
9699
for e, g in zip(expected, got):
97100
self.assertEqualArray(e, g, atol=1e-5)
98101

102+
def test_constant_bool(self):
103+
node = oh.make_node(
104+
"Constant",
105+
[],
106+
["cbool"],
107+
value=from_array_extended(np.array(True, dtype=np.bool_)),
108+
)
109+
ref = ExtendedReferenceEvaluator(node)
110+
got = ref.run(None, {})[0]
111+
self.assertEqual(got.dtype, np.bool_)
112+
self.assertEqual(got, True)
113+
ref = OnnxruntimeEvaluator(node, opsets=21)
114+
got = ref.run(None, {})[0]
115+
self.assertEqual(len(ref._cache), 1)
116+
values = list(ref._cache.values())
117+
_, sess = values[0]
118+
got2 = sess.run(None, {})[0]
119+
self.assertIn(got2.dtype, (torch.bool, np.bool_))
120+
self.assertEqual(got2, True)
121+
122+
self.assertIn(got.dtype, (torch.bool, np.bool_))
123+
self.assertEqual(got, True)
124+
125+
def test_constant_bool_array(self):
126+
node = oh.make_node(
127+
"Constant",
128+
[],
129+
["cbool"],
130+
value=from_array_extended(np.array([True], dtype=np.bool_)),
131+
)
132+
ref = ExtendedReferenceEvaluator(node)
133+
got = ref.run(None, {})[0]
134+
self.assertEqual(got.dtype, np.bool_)
135+
self.assertEqual(got[0], True)
136+
ref = OnnxruntimeEvaluator(node, opsets=21)
137+
got = ref.run(None, {})[0]
138+
self.assertEqual(len(ref._cache), 1)
139+
values = list(ref._cache.values())
140+
_, sess = values[0]
141+
got2 = sess.run(None, {})[0]
142+
self.assertIn(got2.dtype, (torch.bool, np.bool_))
143+
self.assertEqual(got2[0], True)
144+
145+
self.assertIn(got.dtype, (torch.bool, np.bool_))
146+
self.assertEqual(got[0], True)
147+
148+
def test_constant_bool_input(self):
149+
node = oh.make_model(
150+
oh.make_graph(
151+
[oh.make_node("Identity", ["bin"], ["bout"])],
152+
"test",
153+
[oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
154+
[oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
155+
),
156+
ir_version=10,
157+
opset_imports=[oh.make_opsetid("", 18)],
158+
)
159+
feeds = dict(bin=np.array([True], dtype=np.bool_))
160+
ref = ExtendedReferenceEvaluator(node)
161+
162+
got = ref.run(None, feeds)[0]
163+
self.assertEqual(got.dtype, np.bool_)
164+
self.assertEqual(got[0], True)
165+
166+
ref = OnnxruntimeEvaluator(node, opsets=21)
167+
got = ref.run(None, feeds)[0]
168+
self.assertEqual(got.dtype, np.bool_)
169+
self.assertEqual(got[0], True)
170+
171+
feeds = dict(bin=torch.tensor([True], dtype=torch.bool))
172+
got = ref.run(None, feeds)[0]
173+
self.assertEqual(got.dtype, torch.bool)
174+
self.assertEqual(got[0], True)
175+
176+
def test_ort_eval_loop(self):
177+
model = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum")
178+
a = torch.tensor([[39906, 39906]]).long()
179+
example_args = (a,)
180+
model_eval = model.eval()
181+
expected = model(*example_args)
182+
183+
onx = to_onnx(model_eval, example_args, optimize=True)
184+
self.assertIn("Loop", set(n.op_type for n in onx.graph.node))
185+
186+
ref = OnnxruntimeEvaluator(onx, verbose=10)
187+
feeds = dict(
188+
zip([i.name for i in onx.graph.input], [t.detach().numpy() for t in example_args])
189+
)
190+
got = ref.run(None, feeds)
191+
self.assertEqualArray(expected, got[0])
192+
99193

100194
if __name__ == "__main__":
101195
unittest.main(verbosity=2)

onnx_diagnostic/helpers/ort_session.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(
109109
)
110110
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
111111
if isinstance(sess, onnx.ModelProto):
112-
debug_path = "_debug_onnxruntine_evaluator_failure.onnx"
112+
debug_path = "_debug_InferenceSession_last_failure.onnx"
113113
onnx.save(
114114
sess,
115115
debug_path,
@@ -154,6 +154,7 @@ def __init__(
154154
)
155155

156156
self._torch_from_dlpack = _from_dlpack
157+
self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]
157158

158159
def run(
159160
self,
@@ -166,7 +167,19 @@ def run(
166167
ort_outputs = self.sess._sess.run_with_ort_values(
167168
feeds, output_names or self.output_names, self.run_options
168169
)
169-
return ort_outputs
170+
return self._post_process_inplace(ort_outputs)
171+
172+
def _post_process_inplace(self, outputs):
173+
for i in range(len(outputs)):
174+
o = outputs[i]
175+
if self.sess_bool_outputs[i]:
176+
if isinstance(o, np.ndarray):
177+
if o.dtype != np.bool_:
178+
outputs[i] = o.astype(np.bool_)
179+
else:
180+
if o.dtype != torch.bool:
181+
outputs[i] = o.to(torch.bool)
182+
return outputs
170183

171184

172185
class InferenceSessionForNumpy(_InferenceSession):
@@ -221,7 +234,7 @@ def run(
221234
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
222235
# sess.run does not support bfloat16
223236
# res = self.sess.run(output_names, feeds)
224-
return list(self.run_dlpack(output_names, feeds))
237+
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
225238

226239
def run_dlpack(
227240
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
@@ -231,17 +244,23 @@ def run_dlpack(
231244
feeds is a dictionary of :class:`np.ndarray`.
232245
The output device is CPU even if the outputs are on CUDA.
233246
"""
247+
memory = []
234248
new_feeds = {}
235249
for k, v in feeds.items():
236250
if not k:
237251
continue
238-
new_feeds[k] = (
239-
ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
252+
if isinstance(v, np.ndarray):
253+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
240254
v, np_dtype_to_tensor_dtype(v.dtype)
241255
)
242-
if isinstance(v, np.ndarray)
243-
else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
244-
)
256+
elif v.dtype == torch.bool:
257+
vi = v.detach().cpu().numpy()
258+
memory.append(vi)
259+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
260+
vi, onnx.TensorProto.BOOL
261+
)
262+
else:
263+
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
245264

246265
if self.nvtx:
247266
self.torch.cuda.nvtx.range_push("run_with_ort_values")
@@ -421,7 +440,7 @@ def run( # type: ignore
421440
if self.use_training_api:
422441
inputs = [feeds[i] for i in self.input_names]
423442
return self.run_training_api(*inputs, output_names=output_names)
424-
return self.run_dlpack(output_names, feeds)
443+
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
425444

426445
def run_training_api(
427446
self, *inputs, output_names: Optional[List[str]] = None
@@ -471,7 +490,14 @@ def run_dlpack(
471490
assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
472491
if not v.is_contiguous():
473492
v = v.contiguous()
474-
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
493+
if v.dtype == torch.bool:
494+
# Boolean tensors do not work with dlpack
495+
# unless onnxruntime updates the version it is using.
496+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
497+
v.detach().numpy(), onnx.TensorProto.BOOL
498+
)
499+
else:
500+
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
475501
if self.nvtx:
476502
self.torch.cuda.nvtx.range_push("run_with_ort_values")
477503
ort_outputs = self.sess._sess.run_with_ort_values(

0 commit comments

Comments
 (0)