 import contextlib
+import inspect
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -54,6 +55,7 @@ def steal_forward(
     ],
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
+    submodules: bool = False,
     **kwargs,
 ):
     """
@@ -70,12 +72,65 @@ def steal_forward(
     :param dump_file: dumps stolen inputs and outputs in an onnx model,
         they can be restored with :func:`create_input_tensors_from_onnx_model
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
+    :param submodules: if True and model is a module, the list is extended with
+        all the submodules the module contains
+
+    The following example shows how to steal and dump all the inputs / outputs
+    for a module and its submodules, then restore them.
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+        from onnx_diagnostic.helpers.torch_test_helper import steal_forward
+
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(2, 1), torch.rand(2, 1)
+        model = Model()
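+        # every forward call of the model and its submodules is intercepted,
+        # the stolen inputs / outputs are serialized into dump_file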
+        dump_file = "dump_steal_forward_submodules.onnx"
+        with steal_forward(model, submodules=True, dump_file=dump_file):
+            model(*inputs)
+
+        # Let's restore the stolen data.
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        for k, v in sorted(restored.items()):
+            if isinstance(v, tuple):
+                args, kwargs = v
+                print("input", k, args, kwargs)
+            else:
+                print("output", k, v)
     """
+    assert not submodules or isinstance(
+        model, torch.nn.Module
+    ), f"submodules can only be True if model is a module but it is {type(model)}."
     context = dict(iteration=0, **kwargs)
     if "with_shape" not in context and fprint == string_type:
         context["with_shape"] = True
     if not isinstance(model, list):
-        model = [model]
+        assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
+        if submodules:
+            models = []
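+            # named_modules() yields the module itself (under an empty name),
+            # then every submodule with its dotted qualified name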
+            for idx, m in model.named_modules():
+                level = str(idx).split(".")
+                ll = len(level)
+                _, start_line = inspect.getsourcelines(m.forward)
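+                # display name: qualified name, class name and the source line
+                # of its forward method, indented according to nesting depth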
+                name = f"{idx}-{m.__class__.__name__}-{start_line}"
+                models.append((f"{' ' * ll}{name}", m))
+            model = models
+        else:
+            model = [model]
     keep_model_forward = {}
     storage: Optional[Dict[Any, Any]] = {} if dump_file else None
     for mt in model: