Skip to content

Commit 35fe955

Browse files
committed
steal from submodules
1 parent 2444381 commit 35fe955

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,39 @@ def forward(self, x, y):
144144
self.assertEqualAny(restored["main", 0, "O"], res1)
145145
self.assertEqualAny(restored["main", 0, "O"], res2)
146146

147+
@hide_stdout()
148+
def test_steal_forward_submodules(self):
149+
class SubModel(torch.nn.Module):
150+
def forward(self, x):
151+
return x * x
152+
153+
class Model(torch.nn.Module):
154+
def __init__(self):
155+
super().__init__()
156+
self.s1 = SubModel()
157+
self.s2 = SubModel()
158+
159+
def forward(self, x, y):
160+
return self.s1(x) + self.s2(y)
161+
162+
inputs = torch.rand(3, 4), torch.rand(3, 4)
163+
model = Model()
164+
dump_file = self.get_dump_file("test_steal_forward_submodules.onnx")
165+
with steal_forward(model, submodules=True, dump_file=dump_file):
166+
model(*inputs)
167+
restored = create_input_tensors_from_onnx_model(dump_file)
168+
self.assertEqual(
169+
[
170+
("", 0, "I"),
171+
("", 0, "O"),
172+
("s1", 0, "I"),
173+
("s1", 0, "O"),
174+
("s2", 0, "I"),
175+
("s2", 0, "O"),
176+
],
177+
sorted(restored),
178+
)
179+
147180
def test_replace_string_by_dynamic(self):
148181
example = {
149182
"input_ids": {0: "batch_size", 1: "sequence_length"},

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def steal_forward(
5454
],
5555
fprint: Callable = string_type,
5656
dump_file: Optional[str] = None,
57+
submodules: bool = False,
5758
**kwargs,
5859
):
5960
"""
@@ -70,12 +71,25 @@ def steal_forward(
7071
:param dump_file: dumps stolen inputs and outputs in an onnx model,
7172
they can be restored with :func:`create_input_tensors_from_onnx_model
7273
<onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
74+
:param submodules: if True and model is a module, the list extended with all the submodules
75+
the module contains
7376
"""
77+
assert not submodules or isinstance(
78+
model, torch.nn.Module
79+
), f"submodules can only be True if model is a module but is is {type(model)}."
7480
context = dict(iteration=0, **kwargs)
7581
if "with_shape" not in context and fprint == string_type:
7682
context["with_shape"] = True
7783
if not isinstance(model, list):
78-
model = [model]
84+
if submodules:
85+
models = []
86+
for idx, m in model.named_modules():
87+
level = str(idx).split(".")
88+
ll = len(level)
89+
models.append((f"{' ' * ll}{idx}", m))
90+
model = models
91+
else:
92+
model = [model]
7993
keep_model_forward = {}
8094
storage: Optional[Dict[Any, Any]] = {} if dump_file else None
8195
for mt in model:

0 commit comments

Comments
 (0)