Skip to content

Commit 58ba91e

Browse files
committed
Fix examples with scan
1 parent 3470344 commit 58ba91e

File tree

6 files changed

+264
-41
lines changed

6 files changed

+264
-41
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import unittest
2+
import onnx
3+
import torch
4+
import onnxruntime
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
7+
8+
try:
9+
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
10+
except ImportError:
11+
to_onnx = None
12+
13+
14+
class TestOnnxruntimeEvaluator(ExtTestCase):
15+
def test_ort_eval_scan_cdist_add(self):
16+
17+
def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
18+
sub = samex - x.reshape((1, -1))
19+
sq = sub * sub
20+
rd = torch.sqrt(sq.sum(axis=1))
21+
# clone --> UnsupportedAliasMutationException:
22+
# Combine_fn might be aliasing the input!
23+
return [unused.clone(), rd]
24+
25+
class ScanModel(torch.nn.Module):
26+
def forward(self, x):
27+
z = torch.tensor([0], dtype=torch.float32)
28+
y = x.clone()
29+
out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
30+
return out[1]
31+
32+
x = torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32)
33+
model = ScanModel()
34+
expected = model(x)
35+
onx = to_onnx(
36+
model,
37+
(x,),
38+
optimize=True,
39+
export_options=ExportOptions(decomposition_table="default", strict=False),
40+
inline=False,
41+
)
42+
filename = self.get_dump_file("test_ort_eval_scan_cdist_add.onnx")
43+
onnx.save(onx, filename)
44+
inits = [i.name for i in onx.graph.initializer]
45+
self.assertEqual(inits, ["c_lifted_tensor_0"])
46+
name = onx.graph.input[0].name
47+
48+
sess = onnxruntime.InferenceSession(
49+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
50+
)
51+
got = sess.run(None, {name: x.numpy()})[0]
52+
self.assertEqualArray(expected, got)
53+
54+
ref = ExtendedReferenceEvaluator(onx)
55+
got = ref.run(None, {name: x.numpy()})[0]
56+
self.assertEqualArray(expected, got)
57+
58+
orte = OnnxruntimeEvaluator(onx)
59+
got = orte.run(None, {name: x.numpy()})[0]
60+
self.assertEqualArray(expected, got)
61+
62+
63+
if __name__ == "__main__":
64+
unittest.main(verbosity=2)

onnx_diagnostic/reference/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .ops.op_quick_gelu import QuickGelu
3434
from .ops.op_replace_zero import ReplaceZero
3535
from .ops.op_rotary import Rotary
36+
from .ops.op_scan import Scan
3637
from .ops.op_scatter_elements import ScatterElements
3738
from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
3839
from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
@@ -99,6 +100,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
99100
QuickGelu,
100101
ReplaceZero,
101102
Rotary,
103+
Scan,
102104
ScatterElements,
103105
ScatterNDOfShape,
104106
SimplifiedLayerNormalization,
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
from onnx.reference.ops.op_scan import Scan as _Scan
3+
4+
5+
class Scan(_Scan):
6+
7+
def need_context(self) -> bool:
8+
"""Tells the runtime if this node needs the context
9+
(all the results produced so far) as it may silently access
10+
one of them (operator Loop).
11+
The default answer is `False`.
12+
"""
13+
return True
14+
15+
def _run(
16+
self,
17+
*args,
18+
context=None,
19+
body=None,
20+
num_scan_inputs=None,
21+
scan_input_axes=None,
22+
scan_input_directions=None,
23+
scan_output_axes=None,
24+
scan_output_directions=None,
25+
attributes=None,
26+
):
27+
(
28+
num_loop_state_vars,
29+
num_scan_outputs,
30+
output_directions,
31+
max_dir_out,
32+
output_axes,
33+
max_axe_out,
34+
state_names_in,
35+
state_names_out,
36+
scan_names_in,
37+
scan_names_out,
38+
scan_values,
39+
states,
40+
) = self._common_run_shape(*args)
41+
42+
max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
43+
results = [[] for _ in scan_names_out] # type: ignore
44+
45+
for it in range(max_iter):
46+
inputs = context.copy()
47+
inputs.update(dict(zip(state_names_in, states)))
48+
inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})
49+
50+
try:
51+
outputs_list = self._run_body(inputs) # type: ignore
52+
except TypeError as e:
53+
raise TypeError(
54+
f"Unable to call 'run' for type '{type(self.body)}'." # type: ignore
55+
) from e
56+
57+
outputs = dict(zip(self.output_names, outputs_list))
58+
states = [outputs[name] for name in state_names_out]
59+
for i, name in enumerate(scan_names_out):
60+
results[i].append(np.expand_dims(outputs[name], axis=0))
61+
62+
for res in results:
63+
conc = np.vstack(res)
64+
states.append(conc)
65+
return self._check_and_fix_outputs(tuple(states))

0 commit comments

Comments
 (0)