From 35fe9554c21623bee3d89c6f581e0ce6d9f9b3d8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 7 May 2025 16:27:20 +0200 Subject: [PATCH 1/4] steal from submodules --- .../ut_helpers/test_torch_test_helper.py | 33 +++++++++++++++++++ onnx_diagnostic/helpers/torch_test_helper.py | 16 ++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py index b03a481c..ae37a420 100644 --- a/_unittests/ut_helpers/test_torch_test_helper.py +++ b/_unittests/ut_helpers/test_torch_test_helper.py @@ -144,6 +144,39 @@ def forward(self, x, y): self.assertEqualAny(restored["main", 0, "O"], res1) self.assertEqualAny(restored["main", 0, "O"], res2) + @hide_stdout() + def test_steal_forward_submodules(self): + class SubModel(torch.nn.Module): + def forward(self, x): + return x * x + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.s1 = SubModel() + self.s2 = SubModel() + + def forward(self, x, y): + return self.s1(x) + self.s2(y) + + inputs = torch.rand(3, 4), torch.rand(3, 4) + model = Model() + dump_file = self.get_dump_file("test_steal_forward_submodules.onnx") + with steal_forward(model, submodules=True, dump_file=dump_file): + model(*inputs) + restored = create_input_tensors_from_onnx_model(dump_file) + self.assertEqual( + [ + ("", 0, "I"), + ("", 0, "O"), + ("s1", 0, "I"), + ("s1", 0, "O"), + ("s2", 0, "I"), + ("s2", 0, "O"), + ], + sorted(restored), + ) + def test_replace_string_by_dynamic(self): example = { "input_ids": {0: "batch_size", 1: "sequence_length"}, diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index 90cee379..57e52c1e 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -54,6 +54,7 @@ def steal_forward( ], fprint: Callable = string_type, dump_file: Optional[str] = None, + submodules: bool = False, **kwargs, ): 
""" @@ -70,12 +71,25 @@ def steal_forward( :param dump_file: dumps stolen inputs and outputs in an onnx model, they can be restored with :func:`create_input_tensors_from_onnx_model ` + :param submodules: if True and model is a module, the list extended with all the submodules + the module contains """ + assert not submodules or isinstance( + model, torch.nn.Module + ), f"submodules can only be True if model is a module but is is {type(model)}." context = dict(iteration=0, **kwargs) if "with_shape" not in context and fprint == string_type: context["with_shape"] = True if not isinstance(model, list): - model = [model] + if submodules: + models = [] + for idx, m in model.named_modules(): + level = str(idx).split(".") + ll = len(level) + models.append((f"{' ' * ll}{idx}", m)) + model = models + else: + model = [model] keep_model_forward = {} storage: Optional[Dict[Any, Any]] = {} if dump_file else None for mt in model: From 51ca3a3596d4ca9d9376ce51c88cd8976533e13a Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 7 May 2025 16:27:46 +0200 Subject: [PATCH 2/4] changr --- CHANGELOGS.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 126883f7..47a01f9c 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.5.0 +++++ -* :pr:`88`: extends ``steal_forward`` to dump input, outputs in onnx models +* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models * :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test) 0.4.4 From eb75c757714a6311a18df2f03e34116a6defb3bb Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 7 May 2025 16:29:30 +0200 Subject: [PATCH 3/4] mypy --- onnx_diagnostic/helpers/torch_test_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index 57e52c1e..fffb4e3a 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ 
b/onnx_diagnostic/helpers/torch_test_helper.py @@ -81,6 +81,7 @@ def steal_forward( if "with_shape" not in context and fprint == string_type: context["with_shape"] = True if not isinstance(model, list): + assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model" if submodules: models = [] for idx, m in model.named_modules(): From 1a6e874b6fa4a1ec7ada8a085ca3337a621523f1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 7 May 2025 16:50:35 +0200 Subject: [PATCH 4/4] line --- .../ut_helpers/test_torch_test_helper.py | 19 ++++++--- onnx_diagnostic/helpers/torch_test_helper.py | 42 ++++++++++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py index ae37a420..ccbef435 100644 --- a/_unittests/ut_helpers/test_torch_test_helper.py +++ b/_unittests/ut_helpers/test_torch_test_helper.py @@ -165,14 +165,21 @@ def forward(self, x, y): with steal_forward(model, submodules=True, dump_file=dump_file): model(*inputs) restored = create_input_tensors_from_onnx_model(dump_file) + for k, v in sorted(restored.items()): + if isinstance(v, tuple): + args, kwargs = v + print("input", k, args, kwargs) + else: + print("output", k, v) + print(string_type(restored, with_shape=True)) self.assertEqual( [ - ("", 0, "I"), - ("", 0, "O"), - ("s1", 0, "I"), - ("s1", 0, "O"), - ("s2", 0, "I"), - ("s2", 0, "O"), + ("-Model-159", 0, "I"), + ("-Model-159", 0, "O"), + ("s1-SubModel-150", 0, "I"), + ("s1-SubModel-150", 0, "O"), + ("s2-SubModel-150", 0, "I"), + ("s2-SubModel-150", 0, "O"), ], sorted(restored), ) diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index fffb4e3a..e3f60a68 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -1,4 +1,5 @@ import contextlib +import inspect from collections.abc import Iterable from typing import Any, 
Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -73,6 +74,44 @@ def steal_forward( ` :param submodules: if True and model is a module, the list extended with all the submodules the module contains + + The following example shows how to steal and dump all the inputs / outputs + for a module and its submodules, then restores them. + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model + from onnx_diagnostic.helpers.torch_test_helper import steal_forward + + class SubModel(torch.nn.Module): + def forward(self, x): + return x * x + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.s1 = SubModel() + self.s2 = SubModel() + + def forward(self, x, y): + return self.s1(x) + self.s2(y) + + inputs = torch.rand(2, 1), torch.rand(2, 1) + model = Model() + dump_file = "dump_steal_forward_submodules.onnx" + with steal_forward(model, submodules=True, dump_file=dump_file): + model(*inputs) + + # Let's restore the stolen data. + restored = create_input_tensors_from_onnx_model(dump_file) + for k, v in sorted(restored.items()): + if isinstance(v, tuple): + args, kwargs = v + print("input", k, args, kwargs) + else: + print("output", k, v) """ assert not submodules or isinstance( model, torch.nn.Module ), f"submodules can only be True if model is a module but is is {type(model)}." @@ -87,7 +126,9 @@ def steal_forward( for idx, m in model.named_modules(): level = str(idx).split(".") ll = len(level) - models.append((f"{' ' * ll}{idx}", m)) + _, start_line = inspect.getsourcelines(m.forward) + name = f"{idx}-{m.__class__.__name__}-{start_line}" + models.append((f"{' ' * ll}{name}", m)) model = models else: model = [model]