@@ -1,4 +1,5 @@
 import contextlib
+import inspect
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -73,6 +74,44 @@ def steal_forward(
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     :param submodules: if True and model is a module, the list is extended with all the submodules
         the module contains
+
+    The following example shows how to steal and dump all the inputs / outputs
+    for a module and its submodules, then restore them.
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+        from onnx_diagnostic.helpers.torch_test_helper import steal_forward
+
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(2, 1), torch.rand(2, 1)
+        model = Model()
+        dump_file = "dump_steal_forward_submodules.onnx"
+        with steal_forward(model, submodules=True, dump_file=dump_file):
+            model(*inputs)
+
+        # Let's restore the stolen data.
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        for k, v in sorted(restored.items()):
+            if isinstance(v, tuple):
+                args, kwargs = v
+                print("input", k, args, kwargs)
+            else:
+                print("output", k, v)
76114 """
77115 assert not submodules or isinstance (
78116 model , torch .nn .Module
@@ -87,7 +126,9 @@ def steal_forward(
         for idx, m in model.named_modules():
             level = str(idx).split(".")
             ll = len(level)
-            models.append((f"{' ' * ll}{idx}", m))
+            _, start_line = inspect.getsourcelines(m.forward)
+            name = f"{idx}-{m.__class__.__name__}-{start_line}"
+            models.append((f"{' ' * ll}{name}", m))
         model = models
     else:
         model = [model]
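
The last hunk replaces the bare module index with a richer label combining the module path, the class name, and the line where its forward method is defined, so duplicate submodules of the same class stay distinguishable in the dump. Below is a minimal standalone sketch of that naming logic; it is not part of the commit, and the TinyModule class, the idx value, and the nesting level are illustrative only.

    import inspect
    import torch


    class TinyModule(torch.nn.Module):
        def forward(self, x):
            return x + 1


    m = TinyModule()
    idx, ll = "s1", 1  # hypothetical module path and nesting level
    # inspect.getsourcelines returns (source_lines, starting_line_number) of forward
    _, start_line = inspect.getsourcelines(m.forward)
    name = f"{idx}-{m.__class__.__name__}-{start_line}"
    print(f"{' ' * ll}{name}")  # e.g. " s1-TinyModule-6"; the number depends on the file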