import unittest
import onnx
import torch
import onnxruntime
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator

try:
    from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
except ImportError:
    # Optional dependency: both names are set to None so that the tests
    # below can be skipped instead of failing with NameError/TypeError.
    to_onnx = None
    ExportOptions = None


class TestOnnxruntimeEvaluator(ExtTestCase):
    """
    Compares :class:`OnnxruntimeEvaluator` with eager torch on models
    exported with ``to_onnx`` from modules using ``scan`` and ``torch.cond``.
    """

    @unittest.skipIf(to_onnx is None, "experimental_experiment is not installed")
    def test_ort_eval_scan_cdist_add(self):
        """Exports a model built on ``torch.ops.higher_order.scan`` and checks
        onnxruntime, the reference evaluator and OnnxruntimeEvaluator
        all return what torch returns."""

        def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
            sub = samex - x.reshape((1, -1))
            sq = sub * sub
            rd = torch.sqrt(sq.sum(axis=1))
            # clone --> UnsupportedAliasMutationException:
            # Combine_fn might be aliasing the input!
            return [unused.clone(), rd]

        class ScanModel(torch.nn.Module):
            def forward(self, x):
                z = torch.tensor([0], dtype=torch.float32)
                y = x.clone()
                out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
                return out[1]

        x = torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32)
        model = ScanModel()
        expected = model(x)
        onx = to_onnx(
            model,
            (x,),
            optimize=True,
            export_options=ExportOptions(decomposition_table="default", strict=False),
            inline=False,
        )
        # The model is saved to make a failure easier to investigate.
        filename = self.get_dump_file("test_ort_eval_scan_cdist_add.onnx")
        onnx.save(onx, filename)
        inits = [i.name for i in onx.graph.initializer]
        self.assertEqual(inits, ["c_lifted_tensor_0"])
        name = onx.graph.input[0].name
        feeds = {name: x.numpy()}

        # plain onnxruntime
        sess = onnxruntime.InferenceSession(
            onx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        got = sess.run(None, feeds)[0]
        self.assertEqualArray(expected, got)

        # python reference implementation
        ref = ExtendedReferenceEvaluator(onx)
        got = ref.run(None, feeds)[0]
        self.assertEqualArray(expected, got)

        # node by node evaluation with onnxruntime
        orte = OnnxruntimeEvaluator(onx)
        got = orte.run(None, feeds)[0]
        self.assertEqualArray(expected, got)

    @unittest.skipIf(to_onnx is None, "experimental_experiment is not installed")
    def test_ort_eval_cond(self):
        """Exports a model built on ``torch.cond`` and checks both evaluators
        agree with torch for each of the two branches."""

        class TwoInputs(torch.nn.Module):
            def forward(self, x, y):
                def true_fn(x, y):
                    return torch.sin(x), torch.cos(x) + y

                def false_fn(x, y):
                    return torch.cos(x), torch.sin(x) + y

                return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

        x, y = torch.rand(5, 3), torch.rand(5, 3)
        model = TwoInputs()
        onx = to_onnx(model, (x, y), inline=False)
        self.assertEqual(len(onx.functions), 2)

        for cls in (ExtendedReferenceEvaluator, OnnxruntimeEvaluator):
            evaluator = cls(onx)
            # x is drawn from [0, 1): x takes the true branch, -x the false one
            for _x in (x, -x):
                with self.subTest(evaluator=cls.__name__):
                    expected = model(_x, y)
                    got = evaluator.run(
                        None, {"x": _x.detach().numpy(), "y": y.detach().numpy()}
                    )
                    self.assertEqual(len(expected), len(got))
                    for e, g in zip(expected, got):
                        self.assertEqualArray(e, g, atol=1e-5)


if __name__ == "__main__":
    unittest.main(verbosity=2)
def run(
    self,
    output_names: Optional[List[str]],
    feeds: Union[Dict[str, np.ndarray], Dict[str, ORTC.OrtValue]],
) -> Union[List[np.ndarray], List[ORTC.OrtValue]]:
    """
    Runs the wrapped session on *feeds*.

    :param output_names: outputs to compute, a falsy value means
        ``self.output_names`` on the OrtValue path and is forwarded
        unchanged on the numpy path
    :param feeds: inputs indexed by name
    :return: whatever the underlying call returns

    As soon as one value is a :class:`numpy.ndarray`, the call goes through
    :meth:`onnxruntime.InferenceSession.run`. Otherwise the feeds are
    handed to ``run_with_ort_values`` with ``self.run_options``.
    """
    for value in feeds.values():
        if isinstance(value, np.ndarray):
            # numpy path: the public API does the conversion
            return self.sess.run(output_names, feeds)

    requested = output_names or self.output_names
    return self.sess._sess.run_with_ort_values(feeds, requested, self.run_options)
import numpy as np
from onnx.reference.ops.op_scan import Scan as _Scan


class Scan(_Scan):
    """
    Scan kernel whose body is evaluated with the results computed so far
    (the context) in addition to the loop states and the scan inputs.
    """

    def need_context(self) -> bool:
        """Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer for an operator is `False`,
        this class always answers `True` because :meth:`_run`
        feeds the context to the body.
        """
        return True

    def _run(
        self,
        *args,
        context=None,
        body=None,
        num_scan_inputs=None,
        scan_input_axes=None,
        scan_input_directions=None,
        scan_output_axes=None,
        scan_output_directions=None,
        attributes=None,
    ):
        """
        Runs the body once per iteration and stacks the scan outputs.

        :param args: inputs of the node, decoded by ``_common_run_shape``
        :param context: results known before this node runs, made available
            to the body; ``None`` is handled as an empty context
        :param body: not read by this method
        :param num_scan_inputs: not read by this method
        :param scan_input_axes: not read by this method
        :param scan_input_directions: not read by this method
        :param scan_output_axes: not read by this method
        :param scan_output_directions: not read by this method
        :param attributes: not read by this method
        :return: final states followed by the stacked scan outputs,
            after ``_check_and_fix_outputs``
        """
        (
            num_loop_state_vars,
            num_scan_outputs,
            output_directions,
            max_dir_out,
            output_axes,
            max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        ) = self._common_run_shape(*args)

        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        results = [[] for _ in scan_names_out]  # type: ignore

        for it in range(max_iter):
            # context defaults to None: calling copy() on it would fail
            inputs = {} if context is None else context.copy()
            # states and scan inputs take precedence over the context
            inputs.update(dict(zip(state_names_in, states)))
            inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})

            try:
                outputs_list = self._run_body(inputs)  # type: ignore
            except TypeError as e:
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'."  # type: ignore
                ) from e

            outputs = dict(zip(self.output_names, outputs_list))
            states = [outputs[name] for name in state_names_out]
            for i, name in enumerate(scan_names_out):
                # a leading axis is added so that vstack stacks the iterations
                results[i].append(np.expand_dims(outputs[name], axis=0))

        for res in results:
            conc = np.vstack(res)
            states.append(conc)
        return self._check_and_fix_outputs(tuple(states))
self.proto.graph.initializer} + if hasattr(self.proto, "graph") + else {} + ) + self.rt_nodes_ = self.nodes.copy() + self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = ( # noqa: UP037 {(f.domain, f.name): self.__class__(f) for f in self.proto.functions} if hasattr(self.proto, "functions") @@ -124,7 +135,11 @@ def __init__( @property def input_names(self) -> List[str]: "Returns input names." + assert self.proto, "self.proto is empty" if isinstance(self.proto, NodeProto): + assert isinstance( + self.nodes, list + ), f"Unexpected type {type(self.nodes)} for self.nodes" return self.nodes[0].input return [ getattr(o, "name", o) @@ -136,7 +151,11 @@ def input_names(self) -> List[str]: @property def output_names(self) -> List[str]: "Returns output names." + assert self.proto, "self.proto is empty" if isinstance(self.proto, NodeProto): + assert isinstance( + self.nodes, list + ), f"Unexpected type {type(self.nodes)} for self.nodes" return self.nodes[0].output return [ getattr(o, "name", o) @@ -201,28 +220,38 @@ def run( :return: outputs, as a list if return_all is False, as a dictionary if return_all is True """ + if self.rt_nodes_ is None: + # runs a whole + if self.sess_ is None: + assert self.proto, "self.proto is empty" + _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values())) + assert self.sess_, "mypy not happy" + return self.sess_.run(outputs, feed_inputs) if outputs is None: outputs = self.output_names - results: Dict[str, Any] = self.rt_inits_.copy() + results: Dict[str, Any] = (self.rt_inits_ or {}).copy() - for k, v in self.rt_inits_.items(): + for k, v in results.items(): self._log(2, " +C %s: %s", k, v) for k, v in feed_inputs.items(): + assert not isinstance(v, str), f"Unexpected type str for {k!r}" self._log(2, " +I %s: %s", k, v) results[k] = v - for node in self.rt_nodes_: + for node in self.rt_nodes_ or []: self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output) for i in node.input: if i != "" and i not in 
results: raise RuntimeError( f"Unable to find input {i!r} in known results {sorted(results)}, " - f"self.rt_inits_ has {sorted(self.rt_inits_)}, " + f"self.rt_inits_ has {sorted((self.rt_inits_ or {}))}, " f"feed_inputs has {sorted(feed_inputs)}." ) inputs = [(results[i] if i != "" else None) for i in node.input] if node.op_type == "If" and node.domain == "": outputs = self._run_if(node, inputs, results) + elif node.op_type == "Scan" and node.domain == "": + outputs = self._run_scan(node, inputs, results) elif self._is_local_function(node): outputs = self._run_local(node, inputs, results) else: @@ -274,20 +303,24 @@ def _make_model_proto( return onx def _get_sess( - self, node: NodeProto, inputs: List[Any] + self, node: Union[ModelProto, NodeProto], inputs: List[Any] ) -> Tuple[ModelProto, _InferenceSession]: - unique_names = set() - vinputs = [] - for i, it in zip(node.input, inputs): - if i == "" or i in unique_names: - continue - unique_names.add(i) - value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape) - vinputs.append(value) + if isinstance(node, ModelProto): + onx = node + else: + assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node" + unique_names = set() + vinputs = [] + for i, it in zip(node.input, inputs): + if i == "" or i in unique_names: + continue + unique_names.add(i) + value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape) + vinputs.append(value) - # no need to run shape inference - voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output] - onx = self._make_model_proto([node], vinputs, voutputs) + # no need to run shape inference + voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output] + onx = self._make_model_proto([node], vinputs, voutputs) cls = ( InferenceSessionForNumpy @@ -307,9 +340,9 @@ def _get_sess( ) from e return onx, sess - def _get_sess_if( - self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any] - ) -> 
def _get_sess_scan(
    self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
    """
    Wraps a Scan node into a model evaluated as a whole.

    :param node: the Scan node
    :param branch: name of the attribute holding the subgraph
    :param inputs: values of the node inputs
    :param context: results known when the node is reached
    :return: the model and the evaluator running it
    """
    vinputs = self._get_sess_init_subgraph(node, inputs, context)

    # every attribute is visited, the last one named `branch` is kept
    matching = [att.g for att in node.attribute if att.name == branch]
    g = matching[-1] if matching else None
    assert g, f"Missing attribute {branch!r}"

    # The model outputs copy the subgraph outputs (serialization round trip)
    # and are renamed after the node outputs.
    voutputs = []
    for output_name, body_output in zip(node.output, g.output):
        renamed = ValueInfoProto()
        renamed.ParseFromString(body_output.SerializeToString())
        renamed.name = output_name
        voutputs.append(renamed)

    onx = self._make_model_proto([node], vinputs, voutputs)
    evaluator = OnnxruntimeEvaluator(
        onx,
        local_functions=self.local_functions,
        verbose=self.verbose,
        ir_version=self.ir_version,
        opsets=self.opsets,
        whole=True,
        **self.session_kwargs,
    )
    return onx, evaluator


def _run_scan(
    self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
    """
    Runs a node Scan.

    The evaluator built for the node is cached in ``self._cache`` under
    ``(id(node), "body")`` and reused by the following calls.
    """
    attribute_name = "body"
    cache_key = (id(node), attribute_name)
    if cache_key not in self._cache:
        self._cache[cache_key] = self._get_sess_scan(node, attribute_name, inputs, results)
    sess = self._cache[cache_key][1]

    # known results overwrite the node inputs sharing the same name
    feeds = dict(zip(node.input, inputs))
    feeds.update(results)

    assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
    outputs = sess.run(None, feeds)
    assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
    return outputs