diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py
index dc169d03..e53ee65a 100644
--- a/_unittests/ut_helpers/test_torch_test_helper.py
+++ b/_unittests/ut_helpers/test_torch_test_helper.py
@@ -15,6 +15,7 @@
     replace_string_by_dynamic,
     to_any,
     torch_deepcopy,
+    torch_tensor_size,
 )
 from onnx_diagnostic.helpers.cache_helper import (
     make_dynamic_cache,
@@ -204,7 +205,7 @@ def forward(self, x, y):
             else:
                 print("output", k, v)
         print(string_type(restored, with_shape=True))
-        l1, l2 = 182, 191
+        l1, l2 = 183, 192
         self.assertEqual(
             [
                 (f"-Model-{l2}", 0, "I"),
@@ -264,6 +265,7 @@ def test_torch_deepcopy_cache_dce(self):
         c1.key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cc), 1)
 
     def test_torch_deepcopy_mamba_cache(self):
         cache = make_mamba_cache(
@@ -280,6 +282,7 @@ def test_torch_deepcopy_mamba_cache(self):
         cache.conv_states[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cache), 1)
 
     def test_torch_deepcopy_base_model_outputs(self):
         bo = transformers.modeling_outputs.BaseModelOutput(
@@ -292,6 +295,7 @@ def test_torch_deepcopy_base_model_outputs(self):
         bo.last_hidden_state[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(bo), 1)
 
     def test_torch_deepcopy_sliding_windon_cache(self):
         cache = make_sliding_window_cache(
@@ -308,9 +312,11 @@ def test_torch_deepcopy_sliding_windon_cache(self):
         cache.key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cache), 1)
 
     def test_torch_deepcopy_none(self):
         self.assertEmpty(torch_deepcopy(None))
+        self.assertEqual(torch_tensor_size(None), 0)
 
     def test_model_statistics(self):
         class Model(torch.nn.Module):
diff --git a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py
index 132009ca..7f255f2d 100644
--- a/_unittests/ut_reference/test_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -1,8 +1,11 @@
 import unittest
+import numpy as np
 import onnx
+import onnx.helper as oh
 import torch
 import onnxruntime
 from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.helpers.onnx_helper import from_array_extended
 from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
 
 try:
@@ -96,6 +99,97 @@ def false_fn(x, y):
         for e, g in zip(expected, got):
             self.assertEqualArray(e, g, atol=1e-5)
 
+    def test_constant_bool(self):
+        node = oh.make_node(
+            "Constant",
+            [],
+            ["cbool"],
+            value=from_array_extended(np.array(True, dtype=np.bool_)),
+        )
+        ref = ExtendedReferenceEvaluator(node)
+        got = ref.run(None, {})[0]
+        self.assertEqual(got.dtype, np.bool_)
+        self.assertEqual(got, True)
+        ref = OnnxruntimeEvaluator(node, opsets=21)
+        got = ref.run(None, {})[0]
+        self.assertEqual(len(ref._cache), 1)
+        values = list(ref._cache.values())
+        _, sess = values[0]
+        got2 = sess.run(None, {})[0]
+        self.assertIn(got2.dtype, (torch.bool, np.bool_))
+        self.assertEqual(got2, True)
+
+        self.assertIn(got.dtype, (torch.bool, np.bool_))
+        self.assertEqual(got, True)
+
+    def test_constant_bool_array(self):
+        node = oh.make_node(
+            "Constant",
+            [],
+            ["cbool"],
+            value=from_array_extended(np.array([True], dtype=np.bool_)),
+        )
+        ref = ExtendedReferenceEvaluator(node)
+        got = ref.run(None, {})[0]
+        self.assertEqual(got.dtype, np.bool_)
+        self.assertEqual(got[0], True)
+        ref = OnnxruntimeEvaluator(node, opsets=21)
+        got = ref.run(None, {})[0]
+        self.assertEqual(len(ref._cache), 1)
+        values = list(ref._cache.values())
+        _, sess = values[0]
+        got2 = sess.run(None, {})[0]
+        self.assertIn(got2.dtype, (torch.bool, np.bool_))
+        self.assertEqual(got2[0], True)
+
+        self.assertIn(got.dtype, (torch.bool, np.bool_))
+        self.assertEqual(got[0], True)
+
+    def test_constant_bool_input(self):
+        node = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("Identity", ["bin"], ["bout"])],
+                "test",
+                [oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
+                [oh.make_tensor_value_info("bout", onnx.TensorProto.BOOL, [1])],
+            ),
+            ir_version=10,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        feeds = dict(bin=np.array([True], dtype=np.bool_))
+        ref = ExtendedReferenceEvaluator(node)
+
+        got = ref.run(None, feeds)[0]
+        self.assertEqual(got.dtype, np.bool_)
+        self.assertEqual(got[0], True)
+
+        ref = OnnxruntimeEvaluator(node, opsets=21)
+        got = ref.run(None, feeds)[0]
+        self.assertEqual(got.dtype, np.bool_)
+        self.assertEqual(got[0], True)
+
+        feeds = dict(bin=torch.tensor([True], dtype=torch.bool))
+        got = ref.run(None, feeds)[0]
+        self.assertEqual(got.dtype, torch.bool)
+        self.assertEqual(got[0], True)
+
+    def test_ort_eval_loop(self):
+        model = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum")
+        a = torch.tensor([[39906, 39906]]).long()
+        example_args = (a,)
+        model_eval = model.eval()
+        expected = model(*example_args)
+
+        onx = to_onnx(model_eval, example_args, optimize=True)
+        self.assertIn("Loop", set(n.op_type for n in onx.graph.node))
+
+        ref = OnnxruntimeEvaluator(onx, verbose=10)
+        feeds = dict(
+            zip([i.name for i in onx.graph.input], [t.detach().numpy() for t in example_args])
+        )
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py
index a56fabac..79937e6b 100644
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -109,7 +109,7 @@ def __init__(
             )
         except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
             if isinstance(sess, onnx.ModelProto):
-                debug_path = "_debug_onnxruntine_evaluator_failure.onnx"
+                debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
                     sess,
                     debug_path,
@@ -154,6 +154,7 @@ def __init__(
         )
 
         self._torch_from_dlpack = _from_dlpack
+        self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]
 
     def run(
         self,
@@ -166,7 +167,19 @@ def run(
         ort_outputs = self.sess._sess.run_with_ort_values(
             feeds, output_names or self.output_names, self.run_options
         )
-        return ort_outputs
+        return self._post_process_inplace(ort_outputs)
+
+    def _post_process_inplace(self, outputs):
+        for i in range(len(outputs)):
+            o = outputs[i]
+            if self.sess_bool_outputs[i]:
+                if isinstance(o, np.ndarray):
+                    if o.dtype != np.bool_:
+                        outputs[i] = o.astype(np.bool_)
+                else:
+                    if o.dtype != torch.bool:
+                        outputs[i] = o.to(torch.bool)
+        return outputs
 
 
 class InferenceSessionForNumpy(_InferenceSession):
@@ -221,7 +234,7 @@ def run(
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
         # sess.run does not support blfoat16
         # res = self.sess.run(output_names, feeds)
-        return list(self.run_dlpack(output_names, feeds))
+        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
 
     def run_dlpack(
         self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
@@ -231,17 +244,23 @@ def run_dlpack(
         feeds is a dictionary of :class:`np.ndarray`.
         The output device is CPU even if the outputs are on CUDA.
         """
+        memory = []
         new_feeds = {}
         for k, v in feeds.items():
             if not k:
                 continue
-            new_feeds[k] = (
-                ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+            if isinstance(v, np.ndarray):
+                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                     v, np_dtype_to_tensor_dtype(v.dtype)
                 )
-                if isinstance(v, np.ndarray)
-                else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
-            )
+            elif v.dtype == torch.bool:
+                vi = v.detach().cpu().numpy()
+                memory.append(vi)
+                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+                    vi, onnx.TensorProto.BOOL
+                )
+            else:
+                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
 
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("run_with_ort_values")
@@ -421,7 +440,7 @@ def run(  # type: ignore
         if self.use_training_api:
             inputs = [feeds[i] for i in self.input_names]
             return self.run_training_api(*inputs, output_names=output_names)
-        return self.run_dlpack(output_names, feeds)
+        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
 
     def run_training_api(
         self, *inputs, output_names: Optional[List[str]] = None
@@ -471,7 +490,14 @@ def run_dlpack(
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
-            new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
+            if v.dtype == torch.bool:
+                # It does not work with dlpack
+                # unless onnxruntime updates the version it is using.
+                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+                    v.detach().numpy(), onnx.TensorProto.BOOL
+                )
+            else:
+                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("run_with_ort_values")
         ort_outputs = self.sess._sess.run_with_ort_values(
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index 85b3cdec..d0c1fcb2 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -1,11 +1,12 @@
 import contextlib
 import inspect
+import os
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from .helper import string_type
+from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
@@ -16,7 +17,15 @@
 
 
 def _forward_(
-    *args, _f=None, _fprint=string_type, _prefix="", _context=None, _storage=None, **kwargs
+    *args,
+    _f=None,
+    _fprint=string_type,
+    _prefix="",
+    _context=None,
+    _storage=None,
+    _storage_limit=2**27,
+    _verbose=0,
+    **kwargs,
 ):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
@@ -42,7 +51,20 @@ def _forward_(
             print(f"{indent} -> {_fprint(res, **kws)}")
         print(f"{indent}-{_prefix}.")
     if _storage is not None:
-        _storage[(*key, "O")] = torch_deepcopy(res)
+        size = torch_tensor_size(res)
+        if size < _storage_limit:
+            if _verbose:
+                print(
+                    f"-- stores key={key}, size {size // 2**10}Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
+            _storage[(*key, "O")] = torch_deepcopy(res)
+        else:
+            if _verbose:
+                print(
+                    f"-- skips key={key}, size {size // 2**10}Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
     _context["iteration"] += 1
     return res
 
@@ -92,6 +114,8 @@ def steal_forward(
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
     submodules: bool = False,
+    verbose: int = 0,
+    storage_limit: int = 2**27,
     **kwargs,
 ):
     """
@@ -110,6 +134,8 @@
         `
     :param submodules: if True and model is a module, the list extended
         with all the submodules the module contains
+    :param verbose: verbosity
+    :param storage_limit: do not store objects bigger than this
 
     The following examples shows how to steal and dump all the inputs / outputs
     for a module and its submodules, then restores them.
@@ -181,8 +207,16 @@ def forward(self, x, y):
         keep_model_forward[id(m)] = (m, m.forward)
         c = context.copy()
         c["class_name"] = m.__class__.__name__
-        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, **kws: _forward_(  # noqa: E501
-            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, _storage=_s, **kws
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
+            *args,
+            _f=_f,
+            _fprint=_fp,
+            _context=_c,
+            _prefix=_p,
+            _storage=_s,
+            _verbose=_v,
+            _storage_limit=_sl,
+            **kws,
         )
     try:
         yield
@@ -196,13 +230,21 @@ def forward(self, x, y):
             storage.update(_additional_stolen_objects)
             # We clear the cache.
             _additional_stolen_objects.clear()
+            if verbose:
+                size = torch_tensor_size(storage)
+                print(f"-- gathered {len(storage)} stored objects, size={size // 2 ** 20} Mb")
             proto = create_onnx_model_from_input_tensors(storage)
+            if verbose:
+                print("-- dumps stored objects")
             onnx.save(
                 proto,
                 dump_file,
                 save_as_external_data=True,
                 all_tensors_to_one_file=True,
+                location=f"{os.path.split(dump_file)[-1]}.data",
             )
+            if verbose:
+                print("-- done dumping stored objects")
 
 
 @contextlib.contextmanager
@@ -552,6 +594,37 @@ def torch_deepcopy(value: Any) -> Any:
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
 
 
+def torch_tensor_size(value: Any) -> Any:
+    """Returns the number of bytes stored in tensors."""
+    if value is None:
+        return 0
+    if isinstance(value, (int, float, str)):
+        return 0
+    if isinstance(value, (tuple, list, set)):
+        return sum(torch_tensor_size(v) for v in value)
+    if isinstance(value, dict):
+        return sum(torch_tensor_size(v) for v in value.values())
+    if isinstance(value, np.ndarray):
+        return value.size * value.itemsize
+    if hasattr(value, "clone"):
+        return value.numel() * size_type(value.dtype)
+    if value.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
+        return torch_tensor_size(value.key_cache) + torch_tensor_size(value.value_cache)
+    if value.__class__.__name__ == "EncoderDecoderCache":
+        return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
+            value.cross_attention_cache
+        )
+    if value.__class__.__name__ == "MambaCache":
+        return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, spec = torch.utils._pytree.tree_flatten(value)
+        return sum(torch_tensor_size(a) for a in args)
+
+    # We should have code using serialization and deserialization here,
+    # assuming a model cannot be exported without them.
+    raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
+
+
 def model_statistics(model: torch.nn.Module):
     """Returns statistics on a model in a dictionary."""
     n_subs = len(list(model.modules()))
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index c683c81a..3831cd62 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 from onnx import (
+    AttributeProto,
     GraphProto,
     FunctionProto,
     ModelProto,
@@ -9,6 +10,8 @@
     ValueInfoProto,
     helper as oh,
     load,
+    save as onnx_save,
+    shape_inference as shi,
 )
 from onnx.defs import onnx_opset_version
 import onnxruntime
@@ -19,6 +22,7 @@
     InferenceSessionForNumpy,
     _InferenceSession,
 )
+from .evaluator import ExtendedReferenceEvaluator
 
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
@@ -131,6 +135,7 @@ def __init__(
         )
         if local_functions:
             self.local_functions.update(local_functions)
+        self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}
 
     @property
     def input_names(self) -> List[str]:
@@ -238,7 +243,7 @@ def run(
                 self._log(2, " +I %s: %s", k, v)
             results[k] = v
 
-        for node in self.rt_nodes_ or []:
+        for i_node, node in enumerate(self.rt_nodes_ or []):
             self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
             for i in node.input:
                 if i != "" and i not in results:
@@ -250,7 +255,7 @@ def run(
             inputs = [(results[i] if i != "" else None) for i in node.input]
             if node.op_type == "If" and node.domain == "":
                 outputs = self._run_if(node, inputs, results)
-            elif node.op_type == "Scan" and node.domain == "":
+            elif node.op_type in {"Scan", "Loop"} and node.domain == "":
                 outputs = self._run_scan(node, inputs, results)
             elif self._is_local_function(node):
                 outputs = self._run_local(node, inputs, results)
@@ -262,6 +267,8 @@ def run(
                 self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                 assert isinstance(name, str), f"unexpected type for name {type(name)}"
                 results[name] = value
+            if not intermediate:
+                self._clean_unused_inplace(i_node, node, results)
 
         if intermediate:
             return results
@@ -276,6 +283,52 @@ def run(
         )
         return [results[name] for name in output_names if name != ""]
 
+    def _build_garbage_collector(self) -> Dict[str, int]:
+        """
+        Memorizes, for every result, the last node that needs it.
+        Returns a dictionary mapping each result name to the index of that node.
+        """
+        needed = {}
+        for i, node in enumerate(self.rt_nodes_ or []):
+            for name in node.input:
+                needed[name] = i
+            if node.op_type in {"Scan", "If", "Loop"}:
+                hidden = self._get_hidden_node_inputs(node)
+                for name in hidden:
+                    needed[name] = i
+        if isinstance(self.proto, ModelProto):
+            for o in self.proto.graph.output:
+                needed[o.name] = len(self.rt_nodes_ or [])
+        elif isinstance(self.proto, GraphProto):
+            for o in self.proto.output:
+                needed[o.name] = len(self.rt_nodes_ or [])
+        elif isinstance(self.proto, FunctionProto):
+            for o in self.proto.output:
+                needed[o] = len(self.rt_nodes_ or [])
+        return needed
+
+    def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
+        """
+        Cleans all results that are not needed anymore. Some models require memory
+        to be released along the way to be able to run.
+        """
+        if not self.garbage_collector:
+            return
+        for name in node.input:
+            if self.garbage_collector[name] == i_node and name in results:
+                if self.verbose:
+                    t = results[name]
+                    print(f" - deletes: {name} - {t.dtype}:{t.shape}")
+                del results[name]
+        if node.op_type in {"Scan", "If", "Loop"}:
+            hidden = self._get_hidden_node_inputs(node)
+            for name in hidden:
+                if self.garbage_collector[name] == i_node and name in results:
+                    if self.verbose:
+                        t = results[name]
+                        print(f" - deletes: {name} - {t.dtype}:{t.shape}")
+                    del results[name]
+
     def _make_model_proto(
         self,
         nodes: Sequence[NodeProto],
@@ -300,8 +353,42 @@ def _make_model_proto(
         else:
             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
 
+        # Running shape inference here helps to detect bugs earlier.
+        onx = shi.infer_shapes(onx)
         return onx
 
+    @classmethod
+    def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
+        """
+        Returns the hidden inputs (inputs coming from an upper context)
+        used by a subgraph.
+        """
+        hidden = set()
+        memo = set(i.name for i in graph.initializer)
+        memo |= set(i.name for i in graph.sparse_initializer)
+        for node in graph.node:
+            for i in node.input:
+                if i not in memo:
+                    hidden.add(i)
+            for att in node.attribute:
+                if att.type == AttributeProto.GRAPH and att.g:
+                    hid = self._get_hidden_inputs(att.g)
+                    less = set(h for h in hid if h not in memo)
+                    hidden |= less
+            memo |= set(node.output)
+        return hidden
+
+    @classmethod
+    def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
+        """Calls _get_hidden_inputs on every graph attribute of a node."""
+        if node.op_type not in {"Loop", "Scan", "If"}:
+            return set()
+        hidden = set()
+        for att in node.attribute:
+            if att.type == AttributeProto.GRAPH:
+                hidden |= self._get_hidden_inputs(att.g)
+        return hidden - (hidden & set(node.input))
+
     def _get_sess(
         self, node: Union[ModelProto, NodeProto], inputs: List[Any]
     ) -> Tuple[ModelProto, _InferenceSession]:
@@ -309,17 +396,31 @@ def _get_sess(
             onx = node
         else:
             assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
-            unique_names = set()
-            vinputs = []
-            for i, it in zip(node.input, inputs):
-                if i == "" or i in unique_names:
-                    continue
-                unique_names.add(i)
-                value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
-                vinputs.append(value)
+            if node.op_type == "Constant":
+                # Evaluate the constant to get its exact type (booleans in particular).
+                ref = ExtendedReferenceEvaluator(node)
+                cst = ref.run(None, {})[0]
+                vinputs: List[ValueInfoProto] = []
+                voutputs = [
+                    oh.make_tensor_value_info(
+                        node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
+                    )
+                ]
+            else:
+                unique_names = set()
+                vinputs = []
+                for i, it in zip(node.input, inputs):
+                    if i == "" or i in unique_names:
+                        continue
+                    unique_names.add(i)
+                    value = oh.make_tensor_value_info(
+                        i, dtype_to_tensor_dtype(it.dtype), it.shape
+                    )
+                    vinputs.append(value)
+
+                # no need to run shape inference
+                voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
 
-            # no need to run shape inference
-            voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
         onx = self._make_model_proto([node], vinputs, voutputs)
 
         cls = (
@@ -334,6 +435,7 @@ def _get_sess(
             onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
             onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
         ) as e:
+            onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
             raise RuntimeError(
                 f"Unable to infer a session with inputs\n{string_type(inputs)}"
                 f"\ndue to {e}\n{pretty_onnx(onx)}"
@@ -341,7 +443,7 @@ def _get_sess(
         return onx, sess
 
     def _get_sess_init_subgraph(
-        self, node: NodeProto, inputs: List[Any], context: Dict[str, Any]
+        self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
     ) -> List[Any]:
         unique_names = set()
         vinputs = []
@@ -352,8 +454,9 @@ def _get_sess_init_subgraph(
             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)
 
+        reduced_set = self._get_hidden_inputs(g)
         for i, v in context.items():
-            if i not in unique_names:
+            if i in reduced_set and i not in unique_names:
                 unique_names.add(i)
                 value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
                 vinputs.append(value)
@@ -362,13 +465,12 @@
     def _get_sess_if(
         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
-        vinputs = self._get_sess_init_subgraph(node, inputs, context)
-
        g = None
         for att in node.attribute:
             if att.name == branch:
                 g = att.g
         assert g, f"Missing attribute {branch!r}"
+        vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
 
         voutputs = g.output
 
@@ -439,6 +541,7 @@ def _run_if(
             self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+        feeds = {name: results[name] for name in sess.input_names}
         outputs = sess.run(None, feeds)
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
         return outputs
@@ -446,19 +549,18 @@
     def _get_sess_scan(
         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
-        vinputs = self._get_sess_init_subgraph(node, inputs, context)
-
         g = None
         for att in node.attribute:
             if att.name == branch:
                 g = att.g
         assert g, f"Missing attribute {branch!r}"
+        vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
 
+        begin = 0 if node.op_type == "Scan" else 1
         voutputs = []
-        for name, goutput in zip(node.output, g.output):
-            b = goutput.SerializeToString()
+        for name, _goutput in zip(node.output, g.output[begin:]):
             v = ValueInfoProto()
-            v.ParseFromString(b)
+            # v.ParseFromString(goutput.SerializeToString())
             v.name = name
             voutputs.append(v)
 
@@ -492,6 +594,7 @@ def _run_scan(
             self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
f"Missing method run for type {type(sess)}" + feeds = {name: results[name] for name in sess.input_names} outputs = sess.run(None, feeds) assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}" return outputs