
Commit 89c6ee7

update
1 parent: 78ee08c

2 files changed: +68 −19 lines

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 29 additions & 0 deletions
@@ -59,6 +59,35 @@ def forward(self, x, y):
         with steal_forward(model):
             model(*inputs)

+    @hide_stdout()
+    def test_steal_forward_multi(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(3, 4), torch.rand(3, 4)
+        model = Model()
+        with steal_forward(
+            [
+                (
+                    "main",
+                    model,
+                ),
+                (" s1", model.s1),
+                (" s2", model.s2),
+            ]
+        ):
+            model(*inputs)
+
     def test_replace_string_by_dynamic(self):
         example = {
             "input_ids": {0: "batch_size", 1: "sequence_length"},

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 39 additions & 19 deletions
@@ -1,6 +1,6 @@
 import contextlib
 from collections.abc import Iterable
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from .helper import string_type
@@ -12,11 +12,13 @@
 )


-def _forward_(*args, _f=None, _context=None, **kwargs):
+def _forward_(*args, _f=None, _fprint=string_type, _prefix="", _context=None, **kwargs):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
+    indent = " " * (len(_prefix) - len(_prefix.lstrip()))
+    _prefix = _prefix.lstrip()
     print(
-        f"---- stolen forward for class {_context['class_name']} "
+        f"{indent}+{_prefix} -- stolen forward for class {_context['class_name']} "
         f"-- iteration {_context['iteration']}"
     )
     kws = dict(
@@ -25,36 +27,54 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
     )
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         # torch.compiler.is_exporting requires torch>=2.7
-        print(f" <- args={string_type(args, **kws)} --- kwargs={string_type(kwargs, **kws)}")
+        print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
     res = _f(*args, **kwargs)
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
-        print(" --")
-        print(f" -> {string_type(res, **kws)}")
-        print(".")
+        print(f"{indent} -> {_fprint(res, **kws)}")
+        print(f"{indent}-{_prefix}.")
     _context["iteration"] += 1
     return res


 @contextlib.contextmanager
-def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
+def steal_forward(
+    model: Union[
+        Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
+        List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
+    ],
+    fprint: Callable = string_type,
+    **kwargs,
+):
     """
     The necessary modification to steal the forward method and print out inputs
     and outputs. See example :ref:`l-plot-tiny-llm-export`.
+
+    :param model: a model or a list of models to monitor,
+        every model can also be a tuple ``(name, model)``; the name is displayed as well.
+    :param fprint: function used to print out (or dump); by default it is
+        :func:`onnx_diagnostic.helpers.string_type`
+    :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
+        or any other function defined by ``fprint``
     """
-    context = dict(
-        iteration=0,
-        class_name=model.__class__.__name__,
-        with_shape=with_shape,
-        with_min_max=with_min_max,
-    )
-    keep_model_forward = model.forward
-    model.forward = lambda *args, _f=keep_model_forward, _context=context, **kwargs: _forward_(
-        *args, _f=_f, _context=_context, **kwargs
-    )
+    context = dict(iteration=0, **kwargs)
+    if "with_shape" not in context and fprint == string_type:
+        context["with_shape"] = True
+    if not isinstance(model, list):
+        model = [model]
+    keep_model_forward = {}
+    for mt in model:
+        name, m = mt if isinstance(mt, tuple) else ("", mt)
+        keep_model_forward[id(m)] = (m, m.forward)
+        c = context.copy()
+        c["class_name"] = m.__class__.__name__
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, **kws: _forward_(
+            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, **kws
+        )
     try:
         yield
     finally:
-        model.forward = keep_model_forward
+        for f in keep_model_forward.values():
+            f[0].forward = f[1]


 def is_torchdynamo_exporting() -> bool:
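
For readers of this commit, here is a minimal usage sketch of the extended steal_forward API, mirroring the new unit test; the import path is assumed from the file location and the toy model is illustrative only:

import torch
from onnx_diagnostic.helpers.torch_test_helper import steal_forward

class SubModel(torch.nn.Module):
    def forward(self, x):
        return x * x

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s1 = SubModel()
        self.s2 = SubModel()

    def forward(self, x, y):
        return self.s1(x) + self.s2(y)

model = Model()
inputs = torch.rand(3, 4), torch.rand(3, 4)

# Patch the forward method of the parent module and of both submodules;
# the leading space in " s1" / " s2" indents their trace lines under "main".
with steal_forward([("main", model), (" s1", model.s1), (" s2", model.s2)]):
    model(*inputs)
# The original forward methods are restored when the context manager exits.

Each patched module prints its inputs and outputs on every call; the leading whitespace of the given names controls the indentation of the printed trace.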
