101 changes: 101 additions & 0 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -0,0 +1,101 @@
import unittest
import onnx
import torch
import onnxruntime
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator

try:
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
except ImportError:
to_onnx = None


class TestOnnxruntimeEvaluator(ExtTestCase):
def test_ort_eval_scan_cdist_add(self):

def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
sub = samex - x.reshape((1, -1))
sq = sub * sub
rd = torch.sqrt(sq.sum(axis=1))
# clone --> UnsupportedAliasMutationException:
# Combine_fn might be aliasing the input!
return [unused.clone(), rd]

class ScanModel(torch.nn.Module):
def forward(self, x):
z = torch.tensor([0], dtype=torch.float32)
y = x.clone()
out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
return out[1]

x = torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32)
model = ScanModel()
expected = model(x)
onx = to_onnx(
model,
(x,),
optimize=True,
export_options=ExportOptions(decomposition_table="default", strict=False),
inline=False,
)
filename = self.get_dump_file("test_ort_eval_scan_cdist_add.onnx")
onnx.save(onx, filename)
inits = [i.name for i in onx.graph.initializer]
self.assertEqual(inits, ["c_lifted_tensor_0"])
name = onx.graph.input[0].name

sess = onnxruntime.InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, {name: x.numpy()})[0]
self.assertEqualArray(expected, got)

ref = ExtendedReferenceEvaluator(onx)
got = ref.run(None, {name: x.numpy()})[0]
self.assertEqualArray(expected, got)

orte = OnnxruntimeEvaluator(onx)
got = orte.run(None, {name: x.numpy()})[0]
self.assertEqualArray(expected, got)

def test_ort_eval_cond(self):
import torch

class TwoInputs(torch.nn.Module):
def forward(self, x, y):
def true_fn(x, y):
return torch.sin(x), torch.cos(x) + y

def false_fn(x, y):
return torch.cos(x), torch.sin(x) + y

return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

x, y = torch.rand(5, 3), torch.rand(5, 3)
model = TwoInputs()
onx = to_onnx(model, (x, y), inline=False)
self.assertEqual(len(onx.functions), 2)

# ExtendedReferenceEvaluator
ref = ExtendedReferenceEvaluator(onx)
for _x in (x, -x):
expected = model(_x, y)
got = ref.run(None, {"x": _x.detach().numpy(), "y": y.detach().numpy()})
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
self.assertEqualArray(e, g, atol=1e-5)

# OnnxruntimeEvaluator
ref = OnnxruntimeEvaluator(onx)

for _x in (x, -x):
expected = model(_x, y)
got = ref.run(None, {"x": _x.detach().numpy(), "y": y.detach().numpy()})
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
self.assertEqualArray(e, g, atol=1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)
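For reference, the scan body in test_ort_eval_scan_cdist_add computes, one row of x at a time, the Euclidean distances to every row of the cloned input, so the stacked scan output is the full pairwise distance matrix. A short NumPy check of that expectation (illustrative, not part of the change):

import numpy as np

x = np.array([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=np.float32)
# Vectorized equivalent of what the scan accumulates row by row.
pairwise = np.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1))
print(pairwise)  # (3, 3) matrix, matching model(x) in the test above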
13 changes: 13 additions & 0 deletions _unittests/ut_reference/test_ort_evaluator.py
@@ -20,6 +20,7 @@
onnx_dtype_to_np_dtype,
)
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
from onnx_diagnostic.helpers.ort_session import _InferenceSession

TFLOAT = TensorProto.FLOAT

@@ -78,6 +79,18 @@ def test_ort_eval(self):
self.assertEqualArray(expected, got, atol=1e-4)
self.assertIn("Reshape(xm, shape3) -> Z", out)

@ignore_warnings(DeprecationWarning)
def test__inference(self):
model = self._get_model()

feeds = {"X": self._range(32, 128), "Y": self._range(3, 5, 128, 64)}
ref = ExtendedReferenceEvaluator(model)
expected = ref.run(None, feeds)[0]

ort_eval = _InferenceSession(model)
got = ort_eval.run(None, feeds)[0]
self.assertEqualArray(expected, got, atol=1e-4)

@ignore_warnings(DeprecationWarning)
@requires_cuda()
@hide_stdout()
15 changes: 14 additions & 1 deletion onnx_diagnostic/helpers/ort_session.py
@@ -155,6 +155,19 @@ def __init__(

self._torch_from_dlpack = _from_dlpack

def run(
self,
output_names: Optional[List[str]],
feeds: Union[Dict[str, np.ndarray], Dict[str, ORTC.OrtValue]],
) -> Union[List[np.ndarray], List[ORTC.OrtValue]]:
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
if any(isinstance(t, np.ndarray) for t in feeds.values()):
return self.sess.run(output_names, feeds)
ort_outputs = self.sess._sess.run_with_ort_values(
feeds, output_names or self.output_names, self.run_options
)
return ort_outputs


class InferenceSessionForNumpy(_InferenceSession):
"""
@@ -398,7 +411,7 @@ def _ortvalues_to_torch_tensor(
self.torch.cuda.nvtx.range_pop()
return tuple(res)

def run(
def run( # type: ignore
self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
"""
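The run method added above dispatches on the feed type: NumPy feeds go through the regular onnxruntime.InferenceSession.run, while OrtValue feeds use run_with_ort_values. A minimal sketch of the NumPy path, assuming _InferenceSession is constructed from a ModelProto as in test__inference above (the model below is illustrative, not part of the change):

import numpy as np
import onnx
import onnx.helper as oh
from onnx_diagnostic.helpers.ort_session import _InferenceSession

# Illustrative model: Z = MatMul(X, Y).
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("MatMul", ["X", "Y"], ["Z"])],
        "mm",
        [
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3]),
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2]),
        ],
        [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [2, 2])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

feeds = {
    "X": np.random.rand(2, 3).astype(np.float32),
    "Y": np.random.rand(3, 2).astype(np.float32),
}
sess = _InferenceSession(model)
# NumPy feeds take the standard InferenceSession.run branch of the new method.
print(sess.run(None, feeds)[0])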
2 changes: 2 additions & 0 deletions onnx_diagnostic/reference/evaluator.py
@@ -33,6 +33,7 @@
from .ops.op_quick_gelu import QuickGelu
from .ops.op_replace_zero import ReplaceZero
from .ops.op_rotary import Rotary
from .ops.op_scan import Scan
from .ops.op_scatter_elements import ScatterElements
from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
@@ -99,6 +100,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
QuickGelu,
ReplaceZero,
Rotary,
Scan,
ScatterElements,
ScatterNDOfShape,
SimplifiedLayerNormalization,
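With Scan registered in the list above, models containing a standard Scan node run through ExtendedReferenceEvaluator directly. A small hand-built cumulative-sum model as a usage sketch (illustrative, not part of the change):

import numpy as np
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.reference import ExtendedReferenceEvaluator

# Illustrative Scan body: one carried state (running sum) and one scan output per step.
body = oh.make_graph(
    [
        oh.make_node("Add", ["s_in", "x_in"], ["s_out"]),
        oh.make_node("Identity", ["s_out"], ["y_out"]),
    ],
    "cumsum_body",
    [
        oh.make_tensor_value_info("s_in", TensorProto.FLOAT, [1]),
        oh.make_tensor_value_info("x_in", TensorProto.FLOAT, [1]),
    ],
    [
        oh.make_tensor_value_info("s_out", TensorProto.FLOAT, [1]),
        oh.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]),
    ],
)
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node(
                "Scan", ["init", "x"], ["final", "cumsum"],
                body=body, num_scan_inputs=1,
            )
        ],
        "scan_cumsum",
        [
            oh.make_tensor_value_info("init", TensorProto.FLOAT, [1]),
            oh.make_tensor_value_info("x", TensorProto.FLOAT, [3, 1]),
        ],
        [
            oh.make_tensor_value_info("final", TensorProto.FLOAT, [1]),
            oh.make_tensor_value_info("cumsum", TensorProto.FLOAT, [3, 1]),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

ref = ExtendedReferenceEvaluator(model)
final, cumsum = ref.run(None, {
    "init": np.zeros((1,), dtype=np.float32),
    "x": np.array([[1.0], [2.0], [3.0]], dtype=np.float32),
})
# final == [6.], cumsum == [[1.], [3.], [6.]]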
65 changes: 65 additions & 0 deletions onnx_diagnostic/reference/ops/op_scan.py
@@ -0,0 +1,65 @@
import numpy as np
from onnx.reference.ops.op_scan import Scan as _Scan


class Scan(_Scan):

def need_context(self) -> bool:
"""Tells the runtime if this node needs the context
(all the results produced so far) as it may silently access
one of them (operator Loop).
The default answer is `False`.
"""
return True

def _run(
self,
*args,
context=None,
body=None,
num_scan_inputs=None,
scan_input_axes=None,
scan_input_directions=None,
scan_output_axes=None,
scan_output_directions=None,
attributes=None,
):
(
num_loop_state_vars,
num_scan_outputs,
output_directions,
max_dir_out,
output_axes,
max_axe_out,
state_names_in,
state_names_out,
scan_names_in,
scan_names_out,
scan_values,
states,
) = self._common_run_shape(*args)

max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
results = [[] for _ in scan_names_out] # type: ignore

for it in range(max_iter):
inputs = context.copy()
inputs.update(dict(zip(state_names_in, states)))
inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})

try:
outputs_list = self._run_body(inputs) # type: ignore
except TypeError as e:
raise TypeError(
f"Unable to call 'run' for type '{type(self.body)}'." # type: ignore
) from e

outputs = dict(zip(self.output_names, outputs_list))
states = [outputs[name] for name in state_names_out]
for i, name in enumerate(scan_names_out):
results[i].append(np.expand_dims(outputs[name], axis=0))

for res in results:
conc = np.vstack(res)
states.append(conc)
return self._check_and_fix_outputs(tuple(states))
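Stripped of the ONNX plumbing, the loop in _run above follows the usual Scan pattern: carry the state variables through the body and stack each per-iteration output along a new leading axis. A plain-NumPy sketch of that pattern (names and the toy body are illustrative):

import numpy as np

def scan_sketch(body, initial_states, scan_inputs):
    # Carry states through `body`; stack each per-iteration output along axis 0.
    states = list(initial_states)
    collected = None
    n_iter = scan_inputs[0].shape[0]  # default scan axis is 0
    for it in range(n_iter):
        slices = [v[it] for v in scan_inputs]
        new_states, scan_outs = body(states, slices)
        states = list(new_states)
        if collected is None:
            collected = [[] for _ in scan_outs]
        for acc, out in zip(collected, scan_outs):
            acc.append(np.expand_dims(out, axis=0))
    # Final outputs: last state values followed by the stacked scan outputs.
    return tuple(states) + tuple(np.vstack(acc) for acc in collected)

def cumsum_body(states, xs):
    s = states[0] + xs[0]  # running sum is both the new state and the scan output
    return [s], [s]

print(scan_sketch(
    cumsum_body,
    [np.zeros(1, dtype=np.float32)],
    [np.array([[1.0], [2.0], [3.0]], dtype=np.float32)],
))
# -> final state [6.] and stacked scan output [[1.], [3.], [6.]]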