From 89c6ee74c2c1ed986c42cd1217526255a74e69e6 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 7 May 2025 14:46:31 +0200
Subject: [PATCH 1/4] update

---
 .../ut_helpers/test_torch_test_helper.py     | 29 ++++
 onnx_diagnostic/helpers/torch_test_helper.py | 58 +++++++++++++------
 2 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py
index b6a5e86a..a45dfcbd 100644
--- a/_unittests/ut_helpers/test_torch_test_helper.py
+++ b/_unittests/ut_helpers/test_torch_test_helper.py
@@ -59,6 +59,35 @@ def forward(self, x, y):
         with steal_forward(model):
             model(*inputs)
 
+    @hide_stdout()
+    def test_steal_forward_multi(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(3, 4), torch.rand(3, 4)
+        model = Model()
+        with steal_forward(
+            [
+                (
+                    "main",
+                    model,
+                ),
+                (" s1", model.s1),
+                (" s2", model.s2),
+            ]
+        ):
+            model(*inputs)
+
     def test_replace_string_by_dynamic(self):
         example = {
             "input_ids": {0: "batch_size", 1: "sequence_length"},
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index 2bc4f2e2..4bd9beac 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -1,6 +1,6 @@
 import contextlib
 from collections.abc import Iterable
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from .helper import string_type
@@ -12,11 +12,13 @@
 )
 
 
-def _forward_(*args, _f=None, _context=None, **kwargs):
+def _forward_(*args, _f=None, _fprint=string_type, _prefix="", _context=None, **kwargs):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
+    indent = " " * (len(_prefix) - len(_prefix.lstrip()))
+    _prefix = _prefix.lstrip()
     print(
-        f"---- stolen forward for class {_context['class_name']} "
+        f"{indent}+{_prefix} -- stolen forward for class {_context['class_name']} "
         f"-- iteration {_context['iteration']}"
     )
     kws = dict(
@@ -25,36 +27,54 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
     )
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         # torch.compiler.is_exporting requires torch>=2.7
-        print(f" <- args={string_type(args, **kws)} --- kwargs={string_type(kwargs, **kws)}")
+        print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
     res = _f(*args, **kwargs)
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
-        print(" --")
-        print(f" -> {string_type(res, **kws)}")
-        print(".")
+        print(f"{indent} -> {_fprint(res, **kws)}")
+        print(f"{indent}-{_prefix}.")
     _context["iteration"] += 1
     return res
 
 
 @contextlib.contextmanager
-def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
+def steal_forward(
+    model: Union[
+        Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
+        List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
+    ],
+    fprint: Callable = string_type,
+    **kwargs,
+):
     """
     The necessary modification to steal a forward method and print out inputs
     and outputs. See example :ref:`l-plot-tiny-llm-export`.
+
+    :param model: a model or a list of models to monitor, every model can
+        also be a tuple ``(name, model)``; the name is displayed as well
+    :param fprint: function used to print out (or dump), by default, it is
+        :func:`onnx_diagnostic.helpers.string_type`
+    :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
+        or any other function defined by ``fprint``
     """
-    context = dict(
-        iteration=0,
-        class_name=model.__class__.__name__,
-        with_shape=with_shape,
-        with_min_max=with_min_max,
-    )
-    keep_model_forward = model.forward
-    model.forward = lambda *args, _f=keep_model_forward, _context=context, **kwargs: _forward_(
-        *args, _f=_f, _context=_context, **kwargs
-    )
+    context = dict(iteration=0, **kwargs)
+    if "with_shape" not in context and fprint == string_type:
+        context["with_shape"] = True
+    if not isinstance(model, list):
+        model = [model]
+    keep_model_forward = {}
+    for mt in model:
+        name, m = mt if isinstance(mt, tuple) else ("", mt)
+        keep_model_forward[id(m)] = (m, m.forward)
+        c = context.copy()
+        c["class_name"] = m.__class__.__name__
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, **kws: _forward_(
+            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, **kws
+        )
     try:
         yield
     finally:
-        model.forward = keep_model_forward
+        for f in keep_model_forward.values():
+            f[0].forward = f[1]
 
 
 def is_torchdynamo_exporting() -> bool:

From ab387bf9b3780c9e6ea4106761343fa3b6ed4af1 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 7 May 2025 15:03:48 +0200
Subject: [PATCH 2/4] changelogs

---
 CHANGELOGS.rst                               |  1 +
 onnx_diagnostic/helpers/mini_onnx_builder.py | 36 +++++++++----------
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index bae4cfe6..126883f7 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`88`: extends ``steal_forward`` to dump inputs and outputs into an onnx model
 * :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
 
 0.4.4
diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index 2dd77d5b..a4de751b 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -310,7 +310,7 @@ def to_onnx(self) -> ModelProto:
         return model
 
 
-def flatten_iterator(obj: Any, sep: str) -> Iterator:
+def _flatten_iterator(obj: Any, sep: str) -> Iterator:
     """Iterates over all objects."""
     if obj is not None:
         if isinstance(obj, np.ndarray):
@@ -329,10 +329,10 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
         else:
             for i, o in enumerate(obj):
                 if i == len(obj) - 1:
-                    for p, oo in flatten_iterator(o, sep):
+                    for p, oo in _flatten_iterator(o, sep):
                         yield f"tuple.{sep}{p}", oo
                 else:
-                    for p, oo in flatten_iterator(o, sep):
+                    for p, oo in _flatten_iterator(o, sep):
                         yield f"tuple{sep}{p}", oo
     elif isinstance(obj, list):
         if not obj:
@@ -340,10 +340,10 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
         else:
             for i, o in enumerate(obj):
                 if i == len(obj) - 1:
-                    for p, oo in flatten_iterator(o, sep):
+                    for p, oo in _flatten_iterator(o, sep):
                         yield f"list.{sep}{p}", oo
                 else:
-                    for p, oo in flatten_iterator(o, sep):
+                    for p, oo in _flatten_iterator(o, sep):
                         yield f"list{sep}{p}", oo
     elif isinstance(obj, dict):
         if not obj:
@@ -352,13 +352,13 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
             for i, (k, v) in enumerate(obj.items()):
                 assert sep not in k, (
                     f"Key {k!r} cannot contain '{sep}'. "
-                    f"It would interfer with the serialization."
+ f"It would interfere with the serialization." ) if i == len(obj) - 1: - for p, o in flatten_iterator(v, sep): + for p, o in _flatten_iterator(v, sep): yield f"dict._{k}{sep}{p}", o else: - for p, o in flatten_iterator(v, sep): + for p, o in _flatten_iterator(v, sep): yield f"dict_{k}{sep}{p}", o elif obj.__class__.__name__ == "DynamicCache": # transformers @@ -370,10 +370,10 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator: atts = ["key_cache", "value_cache"] for i, att in enumerate(atts): if i == len(atts) - 1: - for p, o in flatten_iterator(getattr(obj, att), sep): + for p, o in _flatten_iterator(getattr(obj, att), sep): yield f"DynamicCache._{att}{sep}{p}", o else: - for p, o in flatten_iterator(getattr(obj, att), sep): + for p, o in _flatten_iterator(getattr(obj, att), sep): yield f"DynamicCache_{att}{sep}{p}", o else: raise NotImplementedError(f"Unexpected type {type(obj)}") @@ -403,7 +403,7 @@ def create_onnx_model_from_input_tensors( switch_low_high = sys.byteorder != "big" builder = MiniOnnxBuilder(sep=sep) - for prefix, o in flatten_iterator(inputs, sep): + for prefix, o in _flatten_iterator(inputs, sep): if o is None: builder.append_output_initializer(prefix, np.array([])) else: @@ -413,7 +413,7 @@ def create_onnx_model_from_input_tensors( return model -def unflatten( +def _unflatten( sep: str, names: List[str], outputs: List[Any], @@ -421,9 +421,7 @@ def unflatten( level: int = 0, device: str = "cpu", ) -> Tuple[int, Tuple[Any, ...]]: - """ - Unflattens a list of outputs flattened with :func:`flatten_iterator`. - """ + """Unflattens a list of outputs flattened with :func:`flatten_iterator`.""" name = names[pos] spl = name.split(sep) if len(spl) == level + 1: @@ -448,7 +446,7 @@ def unflatten( name = names[pos] spl = name.split(sep) prefix = spl[level] - next_pos, value = unflatten( + next_pos, value = _unflatten( sep, names, outputs, pos=pos, level=level + 1, device=device ) @@ -499,7 +497,7 @@ def create_input_tensors_from_onnx_model( device: str = "cpu", engine: str = "ExtendedReferenceEvaluator", sep: str = "___", -) -> Union[Tuple[Any, ...], Dict[str, Any]]: +) -> Any: """ Deserializes tensors stored with function :func:`create_onnx_model_from_input_tensors`. 
@@ -511,7 +509,7 @@ def create_input_tensors_from_onnx_model(
     :param device: moves the tensor to this device
     :param engine: runtime to use, ``ExtendedReferenceEvaluator`` by default, or ``onnxruntime``
     :param sep: separator
-    :return: ModelProto
+    :return: restored data
     """
     if engine == "ExtendedReferenceEvaluator":
         from ..reference import ExtendedReferenceEvaluator
@@ -552,4 +550,4 @@ def create_input_tensors_from_onnx_model(
             return torch.from_numpy(output).to(device)
         raise AssertionError(f"Unexpected name {name!r} in {names}")
 
-    return unflatten(sep, names, got, device=device)[1]
+    return _unflatten(sep, names, got, device=device)[1]

From 85960d66170ed9f70c6b6598a8ae9e014dedb4ec Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 7 May 2025 15:36:48 +0200
Subject: [PATCH 3/4] dump stolen

---
 .../ut_helpers/test_torch_test_helper.py     | 56 +++++++++++++++++++
 onnx_diagnostic/helpers/mini_onnx_builder.py | 50 ++++++++++++-----
 onnx_diagnostic/helpers/torch_test_helper.py | 29 +++++++++-
 3 files changed, 118 insertions(+), 17 deletions(-)

diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py
index a45dfcbd..b03a481c 100644
--- a/_unittests/ut_helpers/test_torch_test_helper.py
+++ b/_unittests/ut_helpers/test_torch_test_helper.py
@@ -20,6 +20,7 @@
     make_mamba_cache,
     make_sliding_window_cache,
 )
+from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
 
 TFLOAT = onnx.TensorProto.FLOAT
 
@@ -88,6 +89,61 @@ def forward(self, x, y):
         ):
             model(*inputs)
 
+    @hide_stdout()
+    def test_steal_forward_dump_file(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(3, 4), torch.rand(3, 4)
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
+        with steal_forward(
+            [
+                (
+                    "main",
+                    model,
+                ),
+                (" s1", model.s1),
+                (" s2", model.s2),
+            ],
+            dump_file=dump_file,
+        ):
+            res1 = model(*inputs)
+            res2 = model(*inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            [
+                ("main", 0, "I"),
+                ("main", 0, "O"),
+                ("main", 1, "I"),
+                ("main", 1, "O"),
+                ("s1", 0, "I"),
+                ("s1", 0, "O"),
+                ("s1", 1, "I"),
+                ("s1", 1, "O"),
+                ("s2", 0, "I"),
+                ("s2", 0, "O"),
+                ("s2", 1, "I"),
+                ("s2", 1, "O"),
+            ],
+            sorted(restored),
+        )
+        self.assertEqualAny(restored["main", 0, "I"], (inputs, {}))
+        self.assertEqualAny(restored["main", 1, "I"], (inputs, {}))
+        self.assertEqualAny(restored["main", 0, "O"], res1)
+        self.assertEqualAny(restored["main", 1, "O"], res2)
+
     def test_replace_string_by_dynamic(self):
         example = {
             "input_ids": {0: "batch_size", 1: "sequence_length"},
diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index a4de751b..363d8f41 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -2,7 +2,7 @@
 import sys
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
-from onnx import GraphProto, ModelProto, TensorProto
+from onnx import GraphProto, ModelProto, NodeProto, TensorProto
 import onnx.helper as oh
 import torch
 from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended
@@ -36,7 +36,4 @@ def proto_from_array(
     )  # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
-    if arr.is_contiguous():
-        arr_cpu = arr.cpu()
-    else:
-        arr_cpu = arr.contiguous().cpu()
+    arr_cpu = arr.cpu() if arr.is_contiguous() else arr.contiguous().cpu()
     numel = torch.numel(arr_cpu)
     element_size = arr_cpu.element_size()
@@ -91,10 +88,10 @@ class MiniOnnxBuilder:
     """
 
     def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___"):
-        self.initializers_dict = {}
-        self.inputs = []
-        self.outputs = []
-        self.nodes = []
+        self.initializers_dict: Dict[str, Any] = {}
+        self.inputs: List[Any] = []
+        self.outputs: List[Any] = []
+        self.nodes: List[NodeProto] = []
         self.opsets = {"": target_opset}
         self.ir_version = ir_version
         self.torch = torch
@@ -270,7 +267,7 @@ def _build_initializers(
 
         return initializer
 
-    res = []
+    res: List[TensorProto] = []
    for k, v in init_dict.items():
         if isinstance(v, TensorProto):
             res.append(v)
@@ -354,12 +351,19 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
                     f"Key {k!r} cannot contain '{sep}'. "
                     f"It would interfere with the serialization."
                 )
+
+                def _mk(k):
+                    if isinstance(k, tuple):
+                        # this assumes the tuple contains simple types
+                        return f"(({','.join(map(str, k))}))"
+                    return str(k)
+
                 if i == len(obj) - 1:
                     for p, o in _flatten_iterator(v, sep):
-                        yield f"dict._{k}{sep}{p}", o
+                        yield f"dict._{_mk(k)}{sep}{p}", o
                 else:
                     for p, o in _flatten_iterator(v, sep):
-                        yield f"dict_{k}{sep}{p}", o
+                        yield f"dict_{_mk(k)}{sep}{p}", o
     elif obj.__class__.__name__ == "DynamicCache":
         # transformers
         import transformers
@@ -420,7 +424,7 @@ def _unflatten(
     pos: int = 0,
    level: int = 0,
     device: str = "cpu",
-) -> Tuple[int, Tuple[Any, ...]]:
+) -> Tuple[int, Any]:
     """Unflattens a list of outputs flattened with :func:`_flatten_iterator`."""
     name = names[pos]
     spl = name.split(sep)
@@ -466,7 +470,7 @@ def _unflatten(
 
     if end:
         if prefix.startswith("dict"):
-            ty = dict
+            ty: type = dict
         elif prefix.startswith("list"):
             ty = list
         elif prefix.startswith("tuple"):
@@ -477,12 +481,30 @@ def _unflatten(
                 break
             pos = next_pos
 
+    def _tryint(s):
+        try:
+            return int(s)
+        except (ValueError, TypeError):
+            if s in {"True", "False"}:
+                return s == "True"
+            return s
+
     def _make(ty: type, res: Any) -> Any:
         if ty.__name__ == "DynamicCache":
             r = ty()
             for k, v in res:
                 setattr(r, k, v)
             return r
+        if ty is dict:
+            d = {}
+            for k, v in res:
+                if k.startswith("((") and k.endswith("))"):
+                    spl = k[2:-2].split(",")
+                    key = tuple(_tryint(s) for s in spl)
+                else:
+                    key = _tryint(k)
+                d[key] = v
+            return d
         return ty(res)
 
     return next_pos, (
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index 4d5e19e6..5b4f1e1a 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -2,6 +2,7 @@
 from collections.abc import Iterable
 from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
+import onnx
 import torch
 from .helper import string_type
 from .cache_helper import (
@@ -10,9 +11,12 @@
     make_sliding_window_cache,
     make_mamba_cache,
 )
+from .mini_onnx_builder import create_onnx_model_from_input_tensors
 
 
-def _forward_(*args, _f=None, _fprint=string_type, _prefix="", _context=None, **kwargs):
+def _forward_(
+    *args, _f=None, _fprint=string_type, _prefix="", _context=None, _storage=None, **kwargs
+):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
     indent = " " * (len(_prefix) - len(_prefix.lstrip()))
@@ -28,10 +32,16 @@ def _forward_(
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         # torch.compiler.is_exporting requires torch>=2.7
         print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
+    if _storage is not None:
+        it = _context["iteration"]
+        key = (_prefix, it)
+        _storage[(*key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))
     res = _f(*args, **kwargs)
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         print(f"{indent} -> {_fprint(res, **kws)}")
         print(f"{indent}-{_prefix}.")
+    if _storage is not None:
+        _storage[(*key, "O")] = torch_deepcopy(res)
     _context["iteration"] += 1
     return res
 
@@ -43,6 +53,7 @@ def steal_forward(
         List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
     ],
     fprint: Callable = string_type,
+    dump_file: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -56,6 +67,9 @@ def steal_forward(
         :func:`onnx_diagnostic.helpers.string_type`
     :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
         or any other function defined by ``fprint``
+    :param dump_file: dumps stolen inputs and outputs into an onnx model;
+        they can be restored with :func:`create_input_tensors_from_onnx_model
+        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     """
     context = dict(iteration=0, **kwargs)
     if "with_shape" not in context and fprint == string_type:
@@ -63,19 +77,28 @@ def steal_forward(
     if not isinstance(model, list):
         model = [model]
     keep_model_forward = {}
+    storage = {} if dump_file else None
     for mt in model:
         name, m = mt if isinstance(mt, tuple) else ("", mt)
         keep_model_forward[id(m)] = (m, m.forward)
         c = context.copy()
         c["class_name"] = m.__class__.__name__
-        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, **kws: _forward_(
-            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, **kws
-        )
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, **kws: _forward_(  # noqa: E501
+            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, _storage=_s, **kws
+        )
     try:
         yield
     finally:
         for f in keep_model_forward.values():
             f[0].forward = f[1]
+        if dump_file:
+            proto = create_onnx_model_from_input_tensors(storage)
+            onnx.save(
+                proto,
+                dump_file,
+                save_as_external_data=False,
+                all_tensors_to_one_file=True,
+            )

From b86c7bd8d7f67f5bffbb5ccf95d255d37c9d346d Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 7 May 2025 15:45:21 +0200
Subject: [PATCH 4/4] fix type

---
 onnx_diagnostic/helpers/mini_onnx_builder.py | 17 +++++++----------
 onnx_diagnostic/helpers/torch_test_helper.py |  4 ++--
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index 363d8f41..8d868d08 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -94,7 +94,6 @@ def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___
         self.nodes: List[NodeProto] = []
         self.opsets = {"": target_opset}
         self.ir_version = ir_version
-        self.torch = torch
         self.sep = sep
 
     def append_output_initializer(
@@ -163,7 +162,7 @@ def append_output_sequence(
             )
         else:
             assert all(
-                isinstance(t, (np.ndarray, self.torch.Tensor)) for t in tensors
+                isinstance(t, (np.ndarray, torch.Tensor)) for t in tensors
             ), f"Nested sequences are not supported, types are {[type(t) for t in tensors]}"
             names = []
             for i, t in enumerate(tensors):
@@ -197,9 +196,7 @@ def append_output_dict(
         self.append_output_initializer(f"{name}{self.sep}keys", np.array(keys, dtype=np.str_))
         self.append_output_sequence(f"{name}{self.sep}values", values)
 
-    def _build_initializers(
-        self, switch_low_high: bool
-    ) -> Tuple[List[TensorProto], Dict[str, TensorProto]]:
+    def _build_initializers(self, switch_low_high: bool) -> List[TensorProto]:
         """
         Builds initializers.
 
@@ -206,7 +203,7 @@ def _build_initializers(self, switch_low_high: bool) -> List[TensorProto]:
         init_dict = self.initializers_dict
         if switch_low_high:
             # Let's try to minimize the time.
-            initializer = []
+            initializer: List[TensorProto] = []
             for k, v in init_dict.items():
                 if isinstance(v, TensorProto):
                     initializer.append(v)
@@ -245,7 +242,7 @@ def _build_initializers(
                 continue
             else:
                 assert isinstance(
-                    v, self.torch.Tensor
+                    v, torch.Tensor
                 ), f"tensor {k!r} has an unexpected type {type(v)}"
                 assert "FakeTensor" not in str(
                     type(v)
@@ -272,9 +269,9 @@ def _build_initializers(
         if isinstance(v, TensorProto):
             res.append(v)
             continue
-        if isinstance(v, self.torch.Tensor):
+        if isinstance(v, torch.Tensor):
             # no string tensor
-            t = self.from_array(v, name=k)
+            t = proto_from_array(v, name=k)
             res.append(t)
             continue
         if isinstance(v, np.ndarray):
@@ -444,7 +441,7 @@ def _unflatten(
             return pos + 1, torch.from_numpy(outputs[pos]).to(device)
         raise AssertionError(f"Unexpected name {name!r} in {names}")
 
-    res = []
+    res: List[Any] = []
     while True:
         assert pos < len(names), f"Something went wrong with names={names!r}\nres={res!r}"
         name = names[pos]
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index 5b4f1e1a..90cee379 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -1,6 +1,6 @@
 import contextlib
 from collections.abc import Iterable
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
@@ -77,7 +77,7 @@ def steal_forward(
     if not isinstance(model, list):
         model = [model]
     keep_model_forward = {}
-    storage = {} if dump_file else None
+    storage: Optional[Dict[Any, Any]] = {} if dump_file else None
     for mt in model:
         name, m = mt if isinstance(mt, tuple) else ("", mt)
         keep_model_forward[id(m)] = (m, m.forward)
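
Usage note: a minimal sketch of the round trip this series enables, assuming the
helpers behave as the tests above exercise them; ``dump.onnx`` is an arbitrary
path chosen for the example::

    import torch
    from onnx_diagnostic.helpers.torch_test_helper import steal_forward
    from onnx_diagnostic.helpers.mini_onnx_builder import (
        create_input_tensors_from_onnx_model,
    )


    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x + y


    model = Model()
    inputs = torch.rand(3, 4), torch.rand(3, 4)

    # Monitors the model, prints every call, and dumps the stolen inputs and
    # outputs into a single onnx file used as a plain tensor container.
    with steal_forward([("main", model)], dump_file="dump.onnx"):
        model(*inputs)

    # Restores a dict keyed by (name, iteration, "I" or "O"); "I" maps to
    # (args, kwargs), "O" to the returned value.
    restored = create_input_tensors_from_onnx_model("dump.onnx")
    for key in sorted(restored):
        print(key)

The container encoding comes from ``_flatten_iterator``: nested tuples, lists,
dicts, and ``DynamicCache`` are flattened into output names joined by ``sep``,
the ``._`` variant (``dict._``, ``list.``, ...) marks the last entry of a
container so ``_unflatten`` knows where it closes, and tuple dictionary keys
such as ``("main", 0, "I")`` are wrapped as ``((main,0,I))`` and parsed back
with ``_tryint``.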
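Design note: ``steal_forward`` binds the stolen method through lambda default
arguments (``_f=m.forward, _c=c, _p=name, _s=storage``) rather than a plain
closure. Python closures bind late, so without the defaults every wrapper
created in the loop would capture the last module of the list. A minimal
illustration of the difference::

    funcs = [lambda: i for i in range(3)]
    print([f() for f in funcs])  # [2, 2, 2] -- closures see the final value of i

    funcs = [lambda _i=i: _i for i in range(3)]
    print([f() for f in funcs])  # [0, 1, 2] -- defaults are bound at definition time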