Skip to content

Commit 1a00215

Browse files
authored
Fixes Scan in ExtendedReferenceEvaluator, OnnxruntimeEvaluator (#95)
* Fix examples with scan * mypy * mypy * eval * add test for test * adds missing method * mypy
1 parent 3470344 commit 1a00215

File tree

8 files changed

+333
-39
lines changed

8 files changed

+333
-39
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import unittest
2+
import onnx
3+
import torch
4+
import onnxruntime
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
7+
8+
try:
9+
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
10+
except ImportError:
11+
to_onnx = None
12+
13+
14+
@unittest.skipIf(
    to_onnx is None,
    "experimental_experiment is not installed, to_onnx and ExportOptions are unavailable",
)
class TestOnnxruntimeEvaluator(ExtTestCase):
    """Compares the evaluators with torch on exported models using control flow.

    Every test exports a torch module with ``to_onnx`` (``inline=False``) and
    checks the outputs of the exported model against the eager torch outputs.
    The whole class is skipped when the optional exporter could not be
    imported: the module then sets ``to_onnx`` to ``None`` and never defines
    ``ExportOptions``, so running the tests would only fail with a
    ``TypeError`` or a ``NameError``.
    """

    def test_ort_eval_scan_cdist_add(self):
        """Exports a model built on ``torch.ops.higher_order.scan`` and runs it
        with onnxruntime, ExtendedReferenceEvaluator and OnnxruntimeEvaluator."""

        def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
            sub = samex - x.reshape((1, -1))
            sq = sub * sub
            rd = torch.sqrt(sq.sum(axis=1))
            # clone --> UnsupportedAliasMutationException:
            # Combine_fn might be aliasing the input!
            return [unused.clone(), rd]

        class ScanModel(torch.nn.Module):
            def forward(self, x):
                z = torch.tensor([0], dtype=torch.float32)
                y = x.clone()
                out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
                return out[1]

        x = torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32)
        model = ScanModel()
        expected = model(x)
        onx = to_onnx(
            model,
            (x,),
            optimize=True,
            export_options=ExportOptions(decomposition_table="default", strict=False),
            inline=False,
        )
        # The model is saved so that it can be inspected when the test fails.
        filename = self.get_dump_file("test_ort_eval_scan_cdist_add.onnx")
        onnx.save(onx, filename)
        inits = [i.name for i in onx.graph.initializer]
        self.assertEqual(inits, ["c_lifted_tensor_0"])
        name = onx.graph.input[0].name

        # Plain onnxruntime.
        sess = onnxruntime.InferenceSession(
            onx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        got = sess.run(None, {name: x.numpy()})[0]
        self.assertEqualArray(expected, got)

        # ExtendedReferenceEvaluator
        ref = ExtendedReferenceEvaluator(onx)
        got = ref.run(None, {name: x.numpy()})[0]
        self.assertEqualArray(expected, got)

        # OnnxruntimeEvaluator
        orte = OnnxruntimeEvaluator(onx)
        got = orte.run(None, {name: x.numpy()})[0]
        self.assertEqualArray(expected, got)

    def test_ort_eval_cond(self):
        """Exports a model built on ``torch.cond`` and checks both branches
        with ExtendedReferenceEvaluator and OnnxruntimeEvaluator."""

        class TwoInputs(torch.nn.Module):
            def forward(self, x, y):
                def true_fn(x, y):
                    return torch.sin(x), torch.cos(x) + y

                def false_fn(x, y):
                    return torch.cos(x), torch.sin(x) + y

                return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

        x, y = torch.rand(5, 3), torch.rand(5, 3)
        model = TwoInputs()
        onx = to_onnx(model, (x, y), inline=False)
        self.assertEqual(len(onx.functions), 2)

        for evaluator_cls in (ExtendedReferenceEvaluator, OnnxruntimeEvaluator):
            with self.subTest(evaluator=evaluator_cls.__name__):
                ref = evaluator_cls(onx)
                # x comes from torch.rand, its sum is positive and the sum of -x
                # is negative: the two inputs exercise the two branches.
                for _x in (x, -x):
                    expected = model(_x, y)
                    got = ref.run(None, {"x": _x.detach().numpy(), "y": y.detach().numpy()})
                    self.assertEqual(len(expected), len(got))
                    for e, g in zip(expected, got):
                        self.assertEqualArray(e, g, atol=1e-5)
98+
99+
100+
if __name__ == "__main__":
101+
unittest.main(verbosity=2)

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
onnx_dtype_to_np_dtype,
2121
)
2222
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
23+
from onnx_diagnostic.helpers.ort_session import _InferenceSession
2324

2425
TFLOAT = TensorProto.FLOAT
2526

@@ -78,6 +79,18 @@ def test_ort_eval(self):
7879
self.assertEqualArray(expected, got, atol=1e-4)
7980
self.assertIn("Reshape(xm, shape3) -> Z", out)
8081

82+
@ignore_warnings(DeprecationWarning)
def test__inference(self):
    """Checks that ``_InferenceSession.run`` returns the same first output
    as ``ExtendedReferenceEvaluator`` on the model built by ``_get_model``."""
    onnx_model = self._get_model()
    inputs = {"X": self._range(32, 128), "Y": self._range(3, 5, 128, 64)}

    # Reference result.
    baseline = ExtendedReferenceEvaluator(onnx_model).run(None, inputs)[0]

    # Result produced through the onnxruntime wrapper.
    produced = _InferenceSession(onnx_model).run(None, inputs)[0]
    self.assertEqualArray(baseline, produced, atol=1e-4)
93+
8194
@ignore_warnings(DeprecationWarning)
8295
@requires_cuda()
8396
@hide_stdout()

onnx_diagnostic/helpers/ort_session.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def __init__(
155155

156156
self._torch_from_dlpack = _from_dlpack
157157

158+
def run(
    self,
    output_names: Optional[List[str]],
    feeds: Union[Dict[str, np.ndarray], Dict[str, ORTC.OrtValue]],
) -> Union[List[np.ndarray], List[ORTC.OrtValue]]:
    """Calls :meth:`onnxruntime.InferenceSession.run`.

    :param output_names: requested outputs, None or an empty list falls back
        to ``self.output_names`` on the OrtValue path
    :param feeds: inputs, numpy arrays or OrtValue
    :return: the outputs returned by the underlying session

    As soon as one feed is a numpy array, the call goes through
    ``self.sess.run``. Otherwise, it goes through
    ``self.sess._sess.run_with_ort_values`` with ``self.run_options``.
    """
    has_numpy_feed = False
    for value in feeds.values():
        if isinstance(value, np.ndarray):
            has_numpy_feed = True
            break

    if not has_numpy_feed:
        return self.sess._sess.run_with_ort_values(
            feeds,
            output_names or self.output_names,
            self.run_options,
        )
    return self.sess.run(output_names, feeds)
170+
158171

159172
class InferenceSessionForNumpy(_InferenceSession):
160173
"""
@@ -398,7 +411,7 @@ def _ortvalues_to_torch_tensor(
398411
self.torch.cuda.nvtx.range_pop()
399412
return tuple(res)
400413

401-
def run(
414+
def run( # type: ignore
402415
self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
403416
) -> Tuple[torch.Tensor, ...]:
404417
"""

onnx_diagnostic/reference/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .ops.op_quick_gelu import QuickGelu
3434
from .ops.op_replace_zero import ReplaceZero
3535
from .ops.op_rotary import Rotary
36+
from .ops.op_scan import Scan
3637
from .ops.op_scatter_elements import ScatterElements
3738
from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
3839
from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
@@ -99,6 +100,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
99100
QuickGelu,
100101
ReplaceZero,
101102
Rotary,
103+
Scan,
102104
ScatterElements,
103105
ScatterNDOfShape,
104106
SimplifiedLayerNormalization,
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
from onnx.reference.ops.op_scan import Scan as _Scan
3+
4+
5+
class Scan(_Scan):
    """Scan operator whose body receives the results of the enclosing graph.

    The only difference with the parent implementation visible here is that
    every iteration feeds the body with the context (results computed so far
    by the outer graph) in addition to the state and scan inputs.
    """

    def need_context(self) -> bool:
        """Tells the runtime that this node needs the context
        (all the results produced so far) as the body may silently access
        one of them.
        This implementation always answers `True`.
        """
        return True

    def _run(
        self,
        *args,
        context=None,
        body=None,
        num_scan_inputs=None,
        scan_input_axes=None,
        scan_input_directions=None,
        scan_output_axes=None,
        scan_output_directions=None,
        attributes=None,
    ):
        """Runs the body once per slice of the scan inputs.

        :param args: state variables followed by the scan inputs
        :param context: results of the enclosing graph, indexed by name,
            made visible to the body; None is handled as an empty context
        :param body: unused here, the body is executed through ``self._run_body``
        :param num_scan_inputs: unused here
        :param scan_input_axes: unused here, the number of iterations is read
            from ``self.input_axes_[0]``
        :param scan_input_directions: unused here
        :param scan_output_axes: unused here
        :param scan_output_directions: unused here
        :param attributes: unused here
        :return: the final states followed by the stacked scan outputs,
            after ``self._check_and_fix_outputs``
        :raises TypeError: if ``self._run_body`` raises a TypeError
        """
        (
            num_loop_state_vars,
            num_scan_outputs,
            output_directions,
            max_dir_out,
            output_axes,
            max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        ) = self._common_run_shape(*args)

        # The signature allows ``context=None``: it is handled as an empty
        # outer scope instead of failing on ``None.copy()``.
        outer_scope = dict(context) if context else {}

        # The first scan input gives the number of iterations.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        results = [[] for _ in scan_names_out]  # type: ignore

        for it in range(max_iter):
            # States and scan slices are added after the context,
            # they win if a name is defined twice.
            inputs = dict(outer_scope)
            inputs.update(zip(state_names_in, states))
            inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})

            try:
                outputs_list = self._run_body(inputs)  # type: ignore
            except TypeError as e:
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'."  # type: ignore
                ) from e

            outputs = dict(zip(self.output_names, outputs_list))
            # The states produced by this iteration feed the next one.
            states = [outputs[name] for name in state_names_out]
            for i, name in enumerate(scan_names_out):
                results[i].append(np.expand_dims(outputs[name], axis=0))

        # Every scan output is the stack of the values collected at each iteration.
        for res in results:
            conc = np.vstack(res)
            states.append(conc)
        return self._check_and_fix_outputs(tuple(states))

0 commit comments

Comments
 (0)