Skip to content

Commit 1a6e874

Browse files
committed
line
1 parent eb75c75 commit 1a6e874

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,21 @@ def forward(self, x, y):
165165
with steal_forward(model, submodules=True, dump_file=dump_file):
166166
model(*inputs)
167167
restored = create_input_tensors_from_onnx_model(dump_file)
168+
for k, v in sorted(restored.items()):
169+
if isinstance(v, tuple):
170+
args, kwargs = v
171+
print("input", k, args, kwargs)
172+
else:
173+
print("output", k, v)
174+
print(string_type(restored, with_shape=True))
168175
self.assertEqual(
169176
[
170-
("", 0, "I"),
171-
("", 0, "O"),
172-
("s1", 0, "I"),
173-
("s1", 0, "O"),
174-
("s2", 0, "I"),
175-
("s2", 0, "O"),
177+
("-Model-159", 0, "I"),
178+
("-Model-159", 0, "O"),
179+
("s1-SubModel-150", 0, "I"),
180+
("s1-SubModel-150", 0, "O"),
181+
("s2-SubModel-150", 0, "I"),
182+
("s2-SubModel-150", 0, "O"),
176183
],
177184
sorted(restored),
178185
)

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import inspect
23
from collections.abc import Iterable
34
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
45
import numpy as np
@@ -73,6 +74,43 @@ def steal_forward(
7374
<onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
7475
:param submodules: if True and model is a module, the list extended with all the submodules
7576
the module contains
77+
78+
The following examples shows how to steal and dump all the inputs / outputs
79+
for a module and its submodules, then restores them.
80+
81+
.. runpython::
82+
:showcode:
83+
84+
import torch
85+
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
86+
87+
class SubModel(torch.nn.Module):
88+
def forward(self, x):
89+
return x * x
90+
91+
class Model(torch.nn.Module):
92+
def __init__(self):
93+
super().__init__()
94+
self.s1 = SubModel()
95+
self.s2 = SubModel()
96+
97+
def forward(self, x, y):
98+
return self.s1(x) + self.s2(y)
99+
100+
inputs = torch.rand(2, 1), torch.rand(2, 1)
101+
model = Model()
102+
dump_file = "dump_steal_forward_submodules.onnx"
103+
with steal_forward(model, submodules=True, dump_file=dump_file):
104+
model(*inputs)
105+
106+
# Let's restore the stolen data.
107+
restored = create_input_tensors_from_onnx_model(dump_file)
108+
for k, v in sorted(restored.items()):
109+
if isinstance(v, tuple):
110+
args, kwargs = v
111+
print("input", k, args, kwargs)
112+
else:
113+
print("output", k, v)
76114
"""
77115
assert not submodules or isinstance(
78116
model, torch.nn.Module
@@ -87,7 +125,9 @@ def steal_forward(
87125
for idx, m in model.named_modules():
88126
level = str(idx).split(".")
89127
ll = len(level)
90-
models.append((f"{' ' * ll}{idx}", m))
128+
_, start_line = inspect.getsourcelines(m.forward)
129+
name = f"{idx}-{m.__class__.__name__}-{start_line}"
130+
models.append((f"{' ' * ll}{name}", m))
91131
model = models
92132
else:
93133
model = [model]

0 commit comments

Comments
 (0)