Skip to content

Commit 8c7302d

Browse files
authored
Steal from submoduels as well (#89)
* steal from submodules * changr * mypy * line
1 parent 2444381 commit 8c7302d

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7-
* :pr:`88`: extends ``steal_forward`` to dump input, outputs in onnx models
7+
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
88
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
99

1010
0.4.4

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,46 @@ def forward(self, x, y):
144144
self.assertEqualAny(restored["main", 0, "O"], res1)
145145
self.assertEqualAny(restored["main", 0, "O"], res2)
146146

147+
@hide_stdout()
148+
def test_steal_forward_submodules(self):
149+
class SubModel(torch.nn.Module):
150+
def forward(self, x):
151+
return x * x
152+
153+
class Model(torch.nn.Module):
154+
def __init__(self):
155+
super().__init__()
156+
self.s1 = SubModel()
157+
self.s2 = SubModel()
158+
159+
def forward(self, x, y):
160+
return self.s1(x) + self.s2(y)
161+
162+
inputs = torch.rand(3, 4), torch.rand(3, 4)
163+
model = Model()
164+
dump_file = self.get_dump_file("test_steal_forward_submodules.onnx")
165+
with steal_forward(model, submodules=True, dump_file=dump_file):
166+
model(*inputs)
167+
restored = create_input_tensors_from_onnx_model(dump_file)
168+
for k, v in sorted(restored.items()):
169+
if isinstance(v, tuple):
170+
args, kwargs = v
171+
print("input", k, args, kwargs)
172+
else:
173+
print("output", k, v)
174+
print(string_type(restored, with_shape=True))
175+
self.assertEqual(
176+
[
177+
("-Model-159", 0, "I"),
178+
("-Model-159", 0, "O"),
179+
("s1-SubModel-150", 0, "I"),
180+
("s1-SubModel-150", 0, "O"),
181+
("s2-SubModel-150", 0, "I"),
182+
("s2-SubModel-150", 0, "O"),
183+
],
184+
sorted(restored),
185+
)
186+
147187
def test_replace_string_by_dynamic(self):
148188
example = {
149189
"input_ids": {0: "batch_size", 1: "sequence_length"},

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import inspect
23
from collections.abc import Iterable
34
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
45
import numpy as np
@@ -54,6 +55,7 @@ def steal_forward(
5455
],
5556
fprint: Callable = string_type,
5657
dump_file: Optional[str] = None,
58+
submodules: bool = False,
5759
**kwargs,
5860
):
5961
"""
@@ -70,12 +72,65 @@ def steal_forward(
7072
:param dump_file: dumps stolen inputs and outputs in an onnx model,
7173
they can be restored with :func:`create_input_tensors_from_onnx_model
7274
<onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
75+
:param submodules: if True and model is a module, the list extended with all the submodules
76+
the module contains
77+
78+
The following examples shows how to steal and dump all the inputs / outputs
79+
for a module and its submodules, then restores them.
80+
81+
.. runpython::
82+
:showcode:
83+
84+
import torch
85+
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
86+
87+
class SubModel(torch.nn.Module):
88+
def forward(self, x):
89+
return x * x
90+
91+
class Model(torch.nn.Module):
92+
def __init__(self):
93+
super().__init__()
94+
self.s1 = SubModel()
95+
self.s2 = SubModel()
96+
97+
def forward(self, x, y):
98+
return self.s1(x) + self.s2(y)
99+
100+
inputs = torch.rand(2, 1), torch.rand(2, 1)
101+
model = Model()
102+
dump_file = "dump_steal_forward_submodules.onnx"
103+
with steal_forward(model, submodules=True, dump_file=dump_file):
104+
model(*inputs)
105+
106+
# Let's restore the stolen data.
107+
restored = create_input_tensors_from_onnx_model(dump_file)
108+
for k, v in sorted(restored.items()):
109+
if isinstance(v, tuple):
110+
args, kwargs = v
111+
print("input", k, args, kwargs)
112+
else:
113+
print("output", k, v)
73114
"""
115+
assert not submodules or isinstance(
116+
model, torch.nn.Module
117+
), f"submodules can only be True if model is a module but is is {type(model)}."
74118
context = dict(iteration=0, **kwargs)
75119
if "with_shape" not in context and fprint == string_type:
76120
context["with_shape"] = True
77121
if not isinstance(model, list):
78-
model = [model]
122+
assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
123+
if submodules:
124+
models = []
125+
for idx, m in model.named_modules():
126+
level = str(idx).split(".")
127+
ll = len(level)
128+
_, start_line = inspect.getsourcelines(m.forward)
129+
name = f"{idx}-{m.__class__.__name__}-{start_line}"
130+
models.append((f"{' ' * ll}{name}", m))
131+
model = models
132+
else:
133+
model = [model]
79134
keep_model_forward = {}
80135
storage: Optional[Dict[Any, Any]] = {} if dump_file else None
81136
for mt in model:

0 commit comments

Comments
 (0)