Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Change Logs
0.5.0
+++++

* :pr:`88`: extends ``steal_forward`` to dump input, outputs in onnx models
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)

0.4.4
Expand Down
40 changes: 40 additions & 0 deletions _unittests/ut_helpers/test_torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,46 @@ def forward(self, x, y):
self.assertEqualAny(restored["main", 0, "O"], res1)
self.assertEqualAny(restored["main", 0, "O"], res2)

@hide_stdout()
def test_steal_forward_submodules(self):
class SubModel(torch.nn.Module):
def forward(self, x):
return x * x

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()

def forward(self, x, y):
return self.s1(x) + self.s2(y)

inputs = torch.rand(3, 4), torch.rand(3, 4)
model = Model()
dump_file = self.get_dump_file("test_steal_forward_submodules.onnx")
with steal_forward(model, submodules=True, dump_file=dump_file):
model(*inputs)
restored = create_input_tensors_from_onnx_model(dump_file)
for k, v in sorted(restored.items()):
if isinstance(v, tuple):
args, kwargs = v
print("input", k, args, kwargs)
else:
print("output", k, v)
print(string_type(restored, with_shape=True))
self.assertEqual(
[
("-Model-159", 0, "I"),
("-Model-159", 0, "O"),
("s1-SubModel-150", 0, "I"),
("s1-SubModel-150", 0, "O"),
("s2-SubModel-150", 0, "I"),
("s2-SubModel-150", 0, "O"),
],
sorted(restored),
)

def test_replace_string_by_dynamic(self):
example = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
Expand Down
57 changes: 56 additions & 1 deletion onnx_diagnostic/helpers/torch_test_helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import inspect
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
Expand Down Expand Up @@ -54,6 +55,7 @@ def steal_forward(
],
fprint: Callable = string_type,
dump_file: Optional[str] = None,
submodules: bool = False,
**kwargs,
):
"""
Expand All @@ -70,12 +72,65 @@ def steal_forward(
:param dump_file: dumps stolen inputs and outputs in an onnx model,
they can be restored with :func:`create_input_tensors_from_onnx_model
<onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
:param submodules: if True and model is a module, the list extended with all the submodules
the module contains

The following examples shows how to steal and dump all the inputs / outputs
for a module and its submodules, then restores them.

.. runpython::
:showcode:

import torch
from onnx_diagnostic.helpers.torch_test_helper import steal_forward

class SubModel(torch.nn.Module):
def forward(self, x):
return x * x

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()

def forward(self, x, y):
return self.s1(x) + self.s2(y)

inputs = torch.rand(2, 1), torch.rand(2, 1)
model = Model()
dump_file = "dump_steal_forward_submodules.onnx"
with steal_forward(model, submodules=True, dump_file=dump_file):
model(*inputs)

# Let's restore the stolen data.
restored = create_input_tensors_from_onnx_model(dump_file)
for k, v in sorted(restored.items()):
if isinstance(v, tuple):
args, kwargs = v
print("input", k, args, kwargs)
else:
print("output", k, v)
"""
assert not submodules or isinstance(
model, torch.nn.Module
), f"submodules can only be True if model is a module but is is {type(model)}."
context = dict(iteration=0, **kwargs)
if "with_shape" not in context and fprint == string_type:
context["with_shape"] = True
if not isinstance(model, list):
model = [model]
assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
if submodules:
models = []
for idx, m in model.named_modules():
level = str(idx).split(".")
ll = len(level)
_, start_line = inspect.getsourcelines(m.forward)
name = f"{idx}-{m.__class__.__name__}-{start_line}"
models.append((f"{' ' * ll}{name}", m))
model = models
else:
model = [model]
keep_model_forward = {}
storage: Optional[Dict[Any, Any]] = {} if dump_file else None
for mt in model:
Expand Down
Loading